diff --git a/StableSR/.gitignore b/StableSR/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..f8eaf0b18fd9c5bf978821d9401eb66bc2640f18
--- /dev/null
+++ b/StableSR/.gitignore
@@ -0,0 +1,134 @@
+# ignored folders
+logs/*
+models/*
+src/
+results/
+wandb/
+output/
+
+*.DS_Store
+.idea
+
+# ignored files
+version.py
+
+# ignored files with suffix
+*.html
+*.png
+*.jpeg
+*.jpg
+*.gif
+*.pth
+*.zip
+# *.txt
+*.svg
+*.ckpt
+
+# template
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+outputs/
diff --git a/StableSR/LICENSE.txt b/StableSR/LICENSE.txt
new file mode 100644
index 0000000000000000000000000000000000000000..44bf750a27c1c2439a418a71c94925db83ad9d37
--- /dev/null
+++ b/StableSR/LICENSE.txt
@@ -0,0 +1,35 @@
+S-Lab License 1.0
+
+Copyright 2022 S-Lab
+
+Redistribution and use for non-commercial purpose in source and
+binary forms, with or without modification, are permitted provided
+that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in
+ the documentation and/or other materials provided with the
+ distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived
+ from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+In the event that redistribution and/or use for commercial purpose in
+source or binary forms, with or without modification is required,
+please contact the contributor(s) of the work.
\ No newline at end of file
diff --git a/StableSR/README.md b/StableSR/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7aa501c6041ce19f4fb284c1c2e5254e4676985b
--- /dev/null
+++ b/StableSR/README.md
@@ -0,0 +1,175 @@
+
+
+
+
+## Exploiting Diffusion Prior for Real-World Image Super-Resolution
+
+[Paper](https://arxiv.org/abs/2305.07015) | [Project Page](https://iceclear.github.io/projects/stablesr/) | [Video](https://www.youtube.com/watch?v=5MZy9Uhpkw4) | [WebUI](https://github.com/pkuliyi2015/sd-webui-stablesr) | [ModelScope](https://modelscope.cn/models/xhlin129/cv_stablesr_image-super-resolution/summary)
+
+
+
+[![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/cjwbw/stablesr) ![visitors](https://visitor-badge.laobi.icu/badge?page_id=IceClear/StableSR)
+
+
+[Jianyi Wang](https://iceclear.github.io/), [Zongsheng Yue](https://zsyoaoa.github.io/), [Shangchen Zhou](https://shangchenzhou.com/), [Kelvin C.K. Chan](https://ckkelvinchan.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/)
+
+S-Lab, Nanyang Technological University
+
+
+
+:star: If StableSR is helpful for your images or projects, please help star this repo. Thanks! :hugs:
+
+### Update
+- **2023.07.31**: Integrated into :rocket: [Replicate](https://replicate.com/explore). Try out the online demo! [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/cjwbw/stablesr) Thanks to [Chenxi](https://github.com/chenxwh) for the implementation!
+- **2023.07.16**: You may reproduce the LDM baseline used in our paper using [LDM-SRtuning](https://github.com/IceClear/LDM-SRtuning) [![GitHub Stars](https://img.shields.io/github/stars/IceClear/LDM-SRtuning?style=social)](https://github.com/IceClear/LDM-SRtuning).
+- **2023.07.14**: :whale: [**ModelScope**](https://modelscope.cn/models/xhlin129/cv_stablesr_image-super-resolution/summary) for StableSR is released!
+- **2023.06.30**: :whale: [**New model**](https://huggingface.co/Iceclear/StableSR/blob/main/stablesr_768v_000139.ckpt) trained on [SD-2.1-768v](https://huggingface.co/stabilityai/stable-diffusion-2-1) is released! Better performance with fewer artifacts!
+- **2023.06.28**: Support training on SD-2.1-768v.
+- **2023.05.22**: :whale: Improve the code to save more GPU memory; 128 --> 512 now needs 8.9G. Enable starting from intermediate steps.
+- **2023.05.20**: :whale: The [**WebUI**](https://github.com/pkuliyi2015/sd-webui-stablesr) [![GitHub Stars](https://img.shields.io/github/stars/pkuliyi2015/sd-webui-stablesr?style=social)](https://github.com/pkuliyi2015/sd-webui-stablesr) of StableSR is available. Thank [Li Yi](https://github.com/pkuliyi2015) for the implementation!
+- **2023.05.13**: Add Colab demo of StableSR.
+- **2023.05.11**: Repo is released.
+
+### TODO
+- [ ] HuggingFace demo (If necessary)
+- [x] ~~Code release~~
+- [x] ~~Update link to paper and project page~~
+- [x] ~~Pretrained models~~
+- [x] ~~Colab demo~~
+- [x] ~~StableSR-768v released~~
+- [x] ~~Replicate demo~~
+
+### Demo on real-world SR
+
+[Comparison 1](https://imgsli.com/MTc2MTI2) [Comparison 2](https://imgsli.com/MTc2MTE2) [Comparison 3](https://imgsli.com/MTc2MTIw)
+[Comparison 4](https://imgsli.com/MTc2MjUy) [Comparison 5](https://imgsli.com/MTc2MTMy) [Comparison 6](https://imgsli.com/MTc2MTMz)
+[Comparison 7](https://imgsli.com/MTc2MjQ5) [Comparison 8](https://imgsli.com/MTc2MTM0) [Comparison 9](https://imgsli.com/MTc2MTM2) [Comparison 10](https://imgsli.com/MTc2MjU0)
+
+For more evaluation, please refer to our [paper](https://arxiv.org/abs/2305.07015) for details.
+
+### Demo on 4K Results
+
+- StableSR is capable of arbitrary upscaling in theory; below is an 8x example with a result beyond 4K (5120x3680).
+The example image is taken from [here](https://github.com/Mikubill/sd-webui-controlnet/blob/main/tests/images/ski.jpg).
+
+[8x SR result beyond 4K](https://imgsli.com/MTc4NDk2)
+
+- We further test StableSR directly on AIGC images and compare it with several diffusion-based upscalers, following the suggestions.
+A 4K demo is [here](https://imgsli.com/MTc4MDg3), which is a 4x SR on the image from [here](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111).
+More comparisons can be found [here](https://github.com/IceClear/StableSR/issues/2).
+
+### Dependencies and Installation
+- PyTorch == 1.12.1
+- CUDA == 11.7
+- pytorch-lightning == 1.4.2
+- xformers == 0.0.16 (Optional)
+- Other required packages in `environment.yaml`
+```
+# git clone this repository
+git clone https://github.com/IceClear/StableSR.git
+cd StableSR
+
+# Create a conda environment and activate it
+conda env create --file environment.yaml
+conda activate stablesr
+
+# Install xformers
+conda install xformers -c xformers/label/dev
+
+# Install taming & clip
+pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
+pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip
+pip install -e .
+```
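+
+You can quickly sanity-check the environment afterwards (a minimal check, not part of the official setup):
+```
+python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
+```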
+
+### Running Examples
+
+#### Train
+Download the pretrained Stable Diffusion models from [[HuggingFace](https://huggingface.co/stabilityai/stable-diffusion-2-1-base)]
+
+- Train the time-aware encoder with SFT: set the ckpt_path in the config file ([Line 22](https://github.com/IceClear/StableSR/blob/main/configs/stableSRNew/v2-finetune_text_T_512.yaml#L22) and [Line 55](https://github.com/IceClear/StableSR/blob/main/configs/stableSRNew/v2-finetune_text_T_512.yaml#L55)).
+```
+python main.py --train --base configs/stableSRNew/v2-finetune_text_T_512.yaml --gpus GPU_ID, --name NAME --scale_lr False
+```
+
+- Train CFW: set the ckpt_path in the config file ([Line 6](https://github.com/IceClear/StableSR/blob/main/configs/autoencoder/autoencoder_kl_64x64x4_resi.yaml#L6)).
+
+You first need to generate training data using the finetuned diffusion model from the first stage. The data folder should be organized as follows:
+```
+CFW_trainingdata/
+ └── inputs
+ └── 00000001.png # LQ images, (512, 512, 3) (resize to 512x512)
+ └── ...
+ └── gts
+ └── 00000001.png # GT images, (512, 512, 3) (512x512)
+ └── ...
+ └── latents
+ └── 00000001.npy # Latent codes (N, 4, 64, 64) of HR images generated by the diffusion U-net, saved in .npy format.
+ └── ...
+ └── samples
+ └── 00000001.png # The HR images generated from latent codes, just to make sure the generated latents are correct.
+ └── ...
+```
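+
+A minimal sketch of how a single latent file could be written and sanity-checked before training (the per-file shape `(1, 4, 64, 64)` is an assumption following the `(N, 4, 64, 64)` convention above with N=1; adapt the paths and shapes to your own generation script):
+```
+import os
+import numpy as np
+
+os.makedirs('CFW_trainingdata/latents', exist_ok=True)
+# placeholder latent for one 512x512 image (real latents come from the diffusion U-net)
+latent = np.random.randn(1, 4, 64, 64).astype(np.float32)
+np.save('CFW_trainingdata/latents/00000001.npy', latent)
+
+# reload and verify the shape before starting CFW training
+assert np.load('CFW_trainingdata/latents/00000001.npy').shape == (1, 4, 64, 64)
+```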
+
+Then you can train CFW:
+```
+python main.py --train --base configs/autoencoder/autoencoder_kl_64x64x4_resi.yaml --gpus GPU_ID, --name NAME --scale_lr False
+```
+
+#### Resume
+
+```
+python main.py --train --base configs/stableSRNew/v2-finetune_text_T_512.yaml --gpus GPU_ID, --resume RESUME_PATH --scale_lr False
+```
+
+#### Test directly
+
+Download the Diffusion and autoencoder pretrained models from [[HuggingFace](https://huggingface.co/Iceclear/StableSR/blob/main/README.md) | [Google Drive](https://drive.google.com/drive/folders/1FBkW9FtTBssM_42kOycMPE0o9U5biYCl?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/jianyi001_e_ntu_edu_sg/Et5HPkgRyyxNk269f5xYCacBpZq-bggFRCDbL9imSQ5QDQ)].
+We use the same color correction scheme introduced in our paper by default.
+You may change ```--colorfix_type wavelet``` for better color correction.
+You may also disable color correction with ```--colorfix_type nofix```.
+
+- Test on 128 --> 512: you need at least 10G GPU memory to run this script (batch size 2 by default).
+```
+python scripts/sr_val_ddpm_text_T_vqganfin_old.py --config configs/stableSRNew/v2-finetune_text_T_512.yaml --ckpt CKPT_PATH --vqgan_ckpt VQGANCKPT_PATH --init-img INPUT_PATH --outdir OUT_DIR --ddpm_steps 200 --dec_w 0.5 --colorfix_type adain
+```
+- Test on arbitrary size w/o chop for autoencoder (for results beyond 512): The memory cost depends on your image size, but is usually above 10G.
+```
+python scripts/sr_val_ddpm_text_T_vqganfin_oldcanvas.py --config configs/stableSRNew/v2-finetune_text_T_512.yaml --ckpt CKPT_PATH --vqgan_ckpt VQGANCKPT_PATH --init-img INPUT_PATH --outdir OUT_DIR --ddpm_steps 200 --dec_w 0.5 --colorfix_type adain
+```
+
+- Test on arbitrary size w/ chop for autoencoder: the current default setting needs at least 18G to run; you may reduce the autoencoder tile size by setting ```--vqgantile_size``` and ```--vqgantile_stride```.
+Note that the minimum tile size is 512 and the stride should be smaller than the tile size. A smaller tile size may introduce more border artifacts.
+```
+python scripts/sr_val_ddpm_text_T_vqganfin_oldcanvas_tile.py --config configs/stableSRNew/v2-finetune_text_T_512.yaml --ckpt CKPT_PATH --vqgan_ckpt VQGANCKPT_PATH --init-img INPUT_PATH --outdir OUT_DIR --ddpm_steps 200 --dec_w 0.5 --colorfix_type adain
+```
+
+- To test the 768 model, you need to set ```--config configs/stableSRNew/v2-finetune_text_T_768v.yaml```, ```--input_size 768``` and ```--ckpt``` accordingly. You can also adjust ```--tile_overlap```, ```--vqgantile_size``` and ```--vqgantile_stride``` as needed. We did not finetune CFW for the 768 model. An illustrative tiled command is shown below.
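+
+For example, a tiled test command for the 768 model might look like the following (the values for ```--tile_overlap```, ```--vqgantile_size``` and ```--vqgantile_stride``` are illustrative, not official defaults):
+```
+python scripts/sr_val_ddpm_text_T_vqganfin_oldcanvas_tile.py --config configs/stableSRNew/v2-finetune_text_T_768v.yaml --ckpt CKPT_768_PATH --vqgan_ckpt VQGANCKPT_PATH --init-img INPUT_PATH --outdir OUT_DIR --ddpm_steps 200 --dec_w 0.5 --colorfix_type adain --input_size 768 --tile_overlap 48 --vqgantile_size 1280 --vqgantile_stride 1000
+```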
+
+#### Test using Replicate API
+```
+import replicate
+model = replicate.models.get("cjwbw/stablesr")
+model.predict(input_image=...)
+```
+See [here](https://replicate.com/cjwbw/stablesr/api) for more information.
+
+### Citation
+If our work is useful for your research, please consider citing:
+
+ @inproceedings{wang2023exploiting,
+ author = {Wang, Jianyi and Yue, Zongsheng and Zhou, Shangchen and Chan, Kelvin CK and Loy, Chen Change},
+ title = {Exploiting Diffusion Prior for Real-World Image Super-Resolution},
+ booktitle = {arXiv preprint arXiv:2305.07015},
+ year = {2023}
+ }
+
+### License
+
+This project is licensed under NTU S-Lab License 1.0. Redistribution and use should follow this license.
+
+### Acknowledgement
+
+This project is based on [stablediffusion](https://github.com/Stability-AI/stablediffusion), [latent-diffusion](https://github.com/CompVis/latent-diffusion), [SPADE](https://github.com/NVlabs/SPADE), [mixture-of-diffusers](https://github.com/albarji/mixture-of-diffusers) and [BasicSR](https://github.com/XPixelGroup/BasicSR). Thanks for their awesome work.
+
+### Contact
+If you have any questions, please feel free to reach out to me at `iceclearwjy@gmail.com`.
diff --git a/StableSR/basicsr/__init__.py b/StableSR/basicsr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..28437544a254656cca7fb7021ef7bbf724cf2879
--- /dev/null
+++ b/StableSR/basicsr/__init__.py
@@ -0,0 +1,12 @@
+# https://github.com/xinntao/BasicSR
+# flake8: noqa
+from .archs import *
+from .data import *
+from .losses import *
+from .metrics import *
+from .models import *
+from .ops import *
+from .test import *
+from .train import *
+from .utils import *
+# from .version import __gitsha__, __version__
diff --git a/StableSR/basicsr/archs/__init__.py b/StableSR/basicsr/archs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..af6bcbd97bb3e4914c3c91dc53e0708bcac66075
--- /dev/null
+++ b/StableSR/basicsr/archs/__init__.py
@@ -0,0 +1,24 @@
+import importlib
+from copy import deepcopy
+from os import path as osp
+
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.registry import ARCH_REGISTRY
+
+__all__ = ['build_network']
+
+# automatically scan and import arch modules for registry
+# scan all the files under the 'archs' folder and collect files ending with '_arch.py'
+arch_folder = osp.dirname(osp.abspath(__file__))
+arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
+# import all the arch modules
+_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]
+
+
+def build_network(opt):
+ opt = deepcopy(opt)
+ network_type = opt.pop('type')
+ net = ARCH_REGISTRY.get(network_type)(**opt)
+ logger = get_root_logger()
+ logger.info(f'Network [{net.__class__.__name__}] is created.')
+ return net
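+
+
+# Example usage (illustrative; the opt dict below is hypothetical): the 'type' key must
+# match a class name registered in ARCH_REGISTRY, e.g. DEResNet from degradat_arch.py,
+# and the remaining keys are forwarded to that class's constructor.
+#   net = build_network({'type': 'DEResNet', 'num_in_ch': 3, 'num_degradation': 2})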
diff --git a/StableSR/basicsr/archs/arch_util.py b/StableSR/basicsr/archs/arch_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f2af24b73c37d3da0664d33a313651be6e33e8f
--- /dev/null
+++ b/StableSR/basicsr/archs/arch_util.py
@@ -0,0 +1,352 @@
+import collections.abc
+import math
+import torch
+import torchvision
+import warnings
+from distutils.version import LooseVersion
+from itertools import repeat
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.nn import init as init
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
+from basicsr.utils import get_root_logger
+
+
+@torch.no_grad()
+def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
+ """Initialize network weights.
+
+ Args:
+ module_list (list[nn.Module] | nn.Module): Modules to be initialized.
+ scale (float): Scale initialized weights, especially for residual
+ blocks. Default: 1.
+ bias_fill (float): The value to fill bias. Default: 0
+ kwargs (dict): Other arguments for initialization function.
+ """
+ if not isinstance(module_list, list):
+ module_list = [module_list]
+ for module in module_list:
+ for m in module.modules():
+ if isinstance(m, nn.Conv2d):
+ init.kaiming_normal_(m.weight, **kwargs)
+ m.weight.data *= scale
+ if m.bias is not None:
+ m.bias.data.fill_(bias_fill)
+ elif isinstance(m, nn.Linear):
+ init.kaiming_normal_(m.weight, **kwargs)
+ m.weight.data *= scale
+ if m.bias is not None:
+ m.bias.data.fill_(bias_fill)
+ elif isinstance(m, _BatchNorm):
+ init.constant_(m.weight, 1)
+ if m.bias is not None:
+ m.bias.data.fill_(bias_fill)
+
+
+def make_layer(basic_block, num_basic_block, **kwarg):
+ """Make layers by stacking the same blocks.
+
+ Args:
+ basic_block (nn.module): nn.module class for basic block.
+ num_basic_block (int): number of blocks.
+
+ Returns:
+ nn.Sequential: Stacked blocks in nn.Sequential.
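+
+    Example (illustrative):
+        >>> trunk = make_layer(ResidualBlockNoBN, 3, num_feat=64)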
+ """
+ layers = []
+ for _ in range(num_basic_block):
+ layers.append(basic_block(**kwarg))
+ return nn.Sequential(*layers)
+
+class PixelShufflePack(nn.Module):
+ """Pixel Shuffle upsample layer.
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ scale_factor (int): Upsample ratio.
+ upsample_kernel (int): Kernel size of Conv layer to expand channels.
+ Returns:
+ Upsampled feature map.
+ """
+
+ def __init__(self, in_channels, out_channels, scale_factor,
+ upsample_kernel):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.scale_factor = scale_factor
+ self.upsample_kernel = upsample_kernel
+ self.upsample_conv = nn.Conv2d(
+ self.in_channels,
+ self.out_channels * scale_factor * scale_factor,
+ self.upsample_kernel,
+ padding=(self.upsample_kernel - 1) // 2)
+ self.init_weights()
+
+ def init_weights(self):
+ """Initialize weights for PixelShufflePack."""
+ default_init_weights(self, 1)
+
+ def forward(self, x):
+ """Forward function for PixelShufflePack.
+ Args:
+ x (Tensor): Input tensor with shape (n, c, h, w).
+ Returns:
+ Tensor: Forward results.
+ """
+ x = self.upsample_conv(x)
+ x = F.pixel_shuffle(x, self.scale_factor)
+ return x
+
+class ResidualBlockNoBN(nn.Module):
+ """Residual block without BN.
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ Default: 64.
+ res_scale (float): Residual scale. Default: 1.
+ pytorch_init (bool): If set to True, use pytorch default init,
+ otherwise, use default_init_weights. Default: False.
+ """
+
+ def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
+ super(ResidualBlockNoBN, self).__init__()
+ self.res_scale = res_scale
+ self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+ self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+ self.relu = nn.ReLU(inplace=True)
+
+ if not pytorch_init:
+ default_init_weights([self.conv1, self.conv2], 0.1)
+
+ def forward(self, x):
+ identity = x
+ out = self.conv2(self.relu(self.conv1(x)))
+ return identity + out * self.res_scale
+
+
+class Upsample(nn.Sequential):
+ """Upsample module.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+ """
+
+ def __init__(self, scale, num_feat):
+ m = []
+ if (scale & (scale - 1)) == 0: # scale = 2^n
+ for _ in range(int(math.log(scale, 2))):
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(2))
+ elif scale == 3:
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(3))
+ else:
+ raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
+ super(Upsample, self).__init__(*m)
+
+
+def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
+ """Warp an image or feature map with optical flow.
+
+ Args:
+ x (Tensor): Tensor with size (n, c, h, w).
+ flow (Tensor): Tensor with size (n, h, w, 2), normal value.
+ interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
+ padding_mode (str): 'zeros' or 'border' or 'reflection'.
+ Default: 'zeros'.
+ align_corners (bool): Before pytorch 1.3, the default value is
+ align_corners=True. After pytorch 1.3, the default value is
+            align_corners=False. Here, we use True as the default.
+
+ Returns:
+ Tensor: Warped image or feature map.
+ """
+ assert x.size()[-2:] == flow.size()[1:3]
+ _, _, h, w = x.size()
+ # create mesh grid
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
+ grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
+ grid.requires_grad = False
+
+ vgrid = grid + flow
+ # scale grid to [-1,1]
+ vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
+ vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
+ vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
+ output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
+
+ # TODO, what if align_corners=False
+ return output
+
+
+def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
+ """Resize a flow according to ratio or shape.
+
+ Args:
+ flow (Tensor): Precomputed flow. shape [N, 2, H, W].
+ size_type (str): 'ratio' or 'shape'.
+ sizes (list[int | float]): the ratio for resizing or the final output
+ shape.
+ 1) The order of ratio should be [ratio_h, ratio_w]. For
+ downsampling, the ratio should be smaller than 1.0 (i.e., ratio
+ < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
+ ratio > 1.0).
+ 2) The order of output_size should be [out_h, out_w].
+ interp_mode (str): The mode of interpolation for resizing.
+ Default: 'bilinear'.
+ align_corners (bool): Whether align corners. Default: False.
+
+ Returns:
+ Tensor: Resized flow.
+ """
+ _, _, flow_h, flow_w = flow.size()
+ if size_type == 'ratio':
+ output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
+ elif size_type == 'shape':
+ output_h, output_w = sizes[0], sizes[1]
+ else:
+ raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
+
+ input_flow = flow.clone()
+ ratio_h = output_h / flow_h
+ ratio_w = output_w / flow_w
+ input_flow[:, 0, :, :] *= ratio_w
+ input_flow[:, 1, :, :] *= ratio_h
+ resized_flow = F.interpolate(
+ input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
+ return resized_flow
+
+
+# TODO: may write a cpp file
+def pixel_unshuffle(x, scale):
+ """ Pixel unshuffle.
+
+ Args:
+ x (Tensor): Input feature with shape (b, c, hh, hw).
+ scale (int): Downsample ratio.
+
+ Returns:
+ Tensor: the pixel unshuffled feature.
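+
+    Example (illustrative shapes):
+        >>> x = torch.randn(1, 3, 8, 8)
+        >>> pixel_unshuffle(x, scale=2).shape
+        torch.Size([1, 12, 4, 4])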
+ """
+ b, c, hh, hw = x.size()
+ out_channel = c * (scale**2)
+ assert hh % scale == 0 and hw % scale == 0
+ h = hh // scale
+ w = hw // scale
+ x_view = x.view(b, c, h, scale, w, scale)
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
+
+
+class DCNv2Pack(ModulatedDeformConvPack):
+ """Modulated deformable conv for deformable alignment.
+
+ Different from the official DCNv2Pack, which generates offsets and masks
+    from the preceding features, this DCNv2Pack takes a separate set of
+ features to generate offsets and masks.
+
+ ``Paper: Delving Deep into Deformable Alignment in Video Super-Resolution``
+ """
+
+ def forward(self, x, feat):
+ out = self.conv_offset(feat)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+ offset = torch.cat((o1, o2), dim=1)
+ mask = torch.sigmoid(mask)
+
+ offset_absmean = torch.mean(torch.abs(offset))
+ if offset_absmean > 50:
+ logger = get_root_logger()
+ logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
+
+ if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
+ return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
+ self.dilation, mask)
+ else:
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
+ self.dilation, self.groups, self.deformable_groups)
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+ 'The distribution of values may be incorrect.',
+ stacklevel=2)
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ low = norm_cdf((a - mean) / std)
+ up = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [low, up], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * low - 1, 2 * up - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution.
+
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
+
+ The values are effectively drawn from the
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.trunc_normal_(w)
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+# From PyTorch
+def _ntuple(n):
+
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
diff --git a/StableSR/basicsr/archs/basicvsr_arch.py b/StableSR/basicsr/archs/basicvsr_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed7b824eae108a9bcca57f1c14dd0d8afafc4f58
--- /dev/null
+++ b/StableSR/basicsr/archs/basicvsr_arch.py
@@ -0,0 +1,336 @@
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import ResidualBlockNoBN, flow_warp, make_layer
+from .edvr_arch import PCDAlignment, TSAFusion
+from .spynet_arch import SpyNet
+
+
+@ARCH_REGISTRY.register()
+class BasicVSR(nn.Module):
+ """A recurrent network for video SR. Now only x4 is supported.
+
+ Args:
+ num_feat (int): Number of channels. Default: 64.
+ num_block (int): Number of residual blocks for each branch. Default: 15
+ spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
+ """
+
+ def __init__(self, num_feat=64, num_block=15, spynet_path=None):
+ super().__init__()
+ self.num_feat = num_feat
+
+ # alignment
+ self.spynet = SpyNet(spynet_path)
+
+ # propagation
+ self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
+ self.forward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
+
+ # reconstruction
+ self.fusion = nn.Conv2d(num_feat * 2, num_feat, 1, 1, 0, bias=True)
+ self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
+ self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
+ self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
+ self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
+
+ self.pixel_shuffle = nn.PixelShuffle(2)
+
+ # activation functions
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+ def get_flow(self, x):
+ b, n, c, h, w = x.size()
+
+ x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
+ x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)
+
+ flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
+ flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)
+
+ return flows_forward, flows_backward
+
+ def forward(self, x):
+ """Forward function of BasicVSR.
+
+ Args:
+ x: Input frames with shape (b, n, c, h, w). n is the temporal dimension / number of frames.
+ """
+ flows_forward, flows_backward = self.get_flow(x)
+ b, n, _, h, w = x.size()
+
+ # backward branch
+ out_l = []
+ feat_prop = x.new_zeros(b, self.num_feat, h, w)
+ for i in range(n - 1, -1, -1):
+ x_i = x[:, i, :, :, :]
+ if i < n - 1:
+ flow = flows_backward[:, i, :, :, :]
+ feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
+ feat_prop = torch.cat([x_i, feat_prop], dim=1)
+ feat_prop = self.backward_trunk(feat_prop)
+ out_l.insert(0, feat_prop)
+
+ # forward branch
+ feat_prop = torch.zeros_like(feat_prop)
+ for i in range(0, n):
+ x_i = x[:, i, :, :, :]
+ if i > 0:
+ flow = flows_forward[:, i - 1, :, :, :]
+ feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
+
+ feat_prop = torch.cat([x_i, feat_prop], dim=1)
+ feat_prop = self.forward_trunk(feat_prop)
+
+ # upsample
+ out = torch.cat([out_l[i], feat_prop], dim=1)
+ out = self.lrelu(self.fusion(out))
+ out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
+ out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
+ out = self.lrelu(self.conv_hr(out))
+ out = self.conv_last(out)
+ base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
+ out += base
+ out_l[i] = out
+
+ return torch.stack(out_l, dim=1)
+
+
+class ConvResidualBlocks(nn.Module):
+ """Conv and residual block used in BasicVSR.
+
+ Args:
+ num_in_ch (int): Number of input channels. Default: 3.
+ num_out_ch (int): Number of output channels. Default: 64.
+ num_block (int): Number of residual blocks. Default: 15.
+ """
+
+ def __init__(self, num_in_ch=3, num_out_ch=64, num_block=15):
+ super().__init__()
+ self.main = nn.Sequential(
+ nn.Conv2d(num_in_ch, num_out_ch, 3, 1, 1, bias=True), nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ make_layer(ResidualBlockNoBN, num_block, num_feat=num_out_ch))
+
+ def forward(self, fea):
+ return self.main(fea)
+
+
+@ARCH_REGISTRY.register()
+class IconVSR(nn.Module):
+ """IconVSR, proposed also in the BasicVSR paper.
+
+ Args:
+ num_feat (int): Number of channels. Default: 64.
+ num_block (int): Number of residual blocks for each branch. Default: 15.
+ keyframe_stride (int): Keyframe stride. Default: 5.
+ temporal_padding (int): Temporal padding. Default: 2.
+ spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
+ edvr_path (str): Path to the pretrained EDVR model. Default: None.
+ """
+
+ def __init__(self,
+ num_feat=64,
+ num_block=15,
+ keyframe_stride=5,
+ temporal_padding=2,
+ spynet_path=None,
+ edvr_path=None):
+ super().__init__()
+
+ self.num_feat = num_feat
+ self.temporal_padding = temporal_padding
+ self.keyframe_stride = keyframe_stride
+
+ # keyframe_branch
+ self.edvr = EDVRFeatureExtractor(temporal_padding * 2 + 1, num_feat, edvr_path)
+ # alignment
+ self.spynet = SpyNet(spynet_path)
+
+ # propagation
+ self.backward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
+ self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
+
+ self.forward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
+ self.forward_trunk = ConvResidualBlocks(2 * num_feat + 3, num_feat, num_block)
+
+ # reconstruction
+ self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
+ self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
+ self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
+ self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
+
+ self.pixel_shuffle = nn.PixelShuffle(2)
+
+ # activation functions
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+ def pad_spatial(self, x):
+ """Apply padding spatially.
+
+ Since the PCD module in EDVR requires that the resolution is a multiple
+ of 4, we apply padding to the input LR images if their resolution is
+ not divisible by 4.
+
+ Args:
+ x (Tensor): Input LR sequence with shape (n, t, c, h, w).
+ Returns:
+ Tensor: Padded LR sequence with shape (n, t, c, h_pad, w_pad).
+ """
+ n, t, c, h, w = x.size()
+
+ pad_h = (4 - h % 4) % 4
+ pad_w = (4 - w % 4) % 4
+
+ # padding
+ x = x.view(-1, c, h, w)
+ x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')
+
+ return x.view(n, t, c, h + pad_h, w + pad_w)
+
+ def get_flow(self, x):
+ b, n, c, h, w = x.size()
+
+ x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
+ x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)
+
+ flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
+ flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)
+
+ return flows_forward, flows_backward
+
+ def get_keyframe_feature(self, x, keyframe_idx):
+ if self.temporal_padding == 2:
+ x = [x[:, [4, 3]], x, x[:, [-4, -5]]]
+ elif self.temporal_padding == 3:
+ x = [x[:, [6, 5, 4]], x, x[:, [-5, -6, -7]]]
+ x = torch.cat(x, dim=1)
+
+ num_frames = 2 * self.temporal_padding + 1
+ feats_keyframe = {}
+ for i in keyframe_idx:
+ feats_keyframe[i] = self.edvr(x[:, i:i + num_frames].contiguous())
+ return feats_keyframe
+
+ def forward(self, x):
+ b, n, _, h_input, w_input = x.size()
+
+ x = self.pad_spatial(x)
+ h, w = x.shape[3:]
+
+ keyframe_idx = list(range(0, n, self.keyframe_stride))
+ if keyframe_idx[-1] != n - 1:
+ keyframe_idx.append(n - 1) # last frame is a keyframe
+
+ # compute flow and keyframe features
+ flows_forward, flows_backward = self.get_flow(x)
+ feats_keyframe = self.get_keyframe_feature(x, keyframe_idx)
+
+ # backward branch
+ out_l = []
+ feat_prop = x.new_zeros(b, self.num_feat, h, w)
+ for i in range(n - 1, -1, -1):
+ x_i = x[:, i, :, :, :]
+ if i < n - 1:
+ flow = flows_backward[:, i, :, :, :]
+ feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
+ if i in keyframe_idx:
+ feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
+ feat_prop = self.backward_fusion(feat_prop)
+ feat_prop = torch.cat([x_i, feat_prop], dim=1)
+ feat_prop = self.backward_trunk(feat_prop)
+ out_l.insert(0, feat_prop)
+
+ # forward branch
+ feat_prop = torch.zeros_like(feat_prop)
+ for i in range(0, n):
+ x_i = x[:, i, :, :, :]
+ if i > 0:
+ flow = flows_forward[:, i - 1, :, :, :]
+ feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
+ if i in keyframe_idx:
+ feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
+ feat_prop = self.forward_fusion(feat_prop)
+
+ feat_prop = torch.cat([x_i, out_l[i], feat_prop], dim=1)
+ feat_prop = self.forward_trunk(feat_prop)
+
+ # upsample
+ out = self.lrelu(self.pixel_shuffle(self.upconv1(feat_prop)))
+ out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
+ out = self.lrelu(self.conv_hr(out))
+ out = self.conv_last(out)
+ base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
+ out += base
+ out_l[i] = out
+
+ return torch.stack(out_l, dim=1)[..., :4 * h_input, :4 * w_input]
+
+
+class EDVRFeatureExtractor(nn.Module):
+ """EDVR feature extractor used in IconVSR.
+
+ Args:
+ num_input_frame (int): Number of input frames.
+ num_feat (int): Number of feature channels
+ load_path (str): Path to the pretrained weights of EDVR. Default: None.
+ """
+
+ def __init__(self, num_input_frame, num_feat, load_path):
+
+ super(EDVRFeatureExtractor, self).__init__()
+
+ self.center_frame_idx = num_input_frame // 2
+
+ # extract pyramid features
+ self.conv_first = nn.Conv2d(3, num_feat, 3, 1, 1)
+ self.feature_extraction = make_layer(ResidualBlockNoBN, 5, num_feat=num_feat)
+ self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+ self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+ self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+
+ # pcd and tsa module
+ self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=8)
+ self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_input_frame, center_frame_idx=self.center_frame_idx)
+
+ # activation function
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+ if load_path:
+ self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
+
+ def forward(self, x):
+ b, n, c, h, w = x.size()
+
+ # extract features for each frame
+ # L1
+ feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))
+ feat_l1 = self.feature_extraction(feat_l1)
+ # L2
+ feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
+ feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
+ # L3
+ feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
+ feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))
+
+ feat_l1 = feat_l1.view(b, n, -1, h, w)
+ feat_l2 = feat_l2.view(b, n, -1, h // 2, w // 2)
+ feat_l3 = feat_l3.view(b, n, -1, h // 4, w // 4)
+
+ # PCD alignment
+ ref_feat_l = [ # reference feature list
+ feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
+ feat_l3[:, self.center_frame_idx, :, :, :].clone()
+ ]
+ aligned_feat = []
+ for i in range(n):
+ nbr_feat_l = [ # neighboring feature list
+ feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
+ ]
+ aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
+ aligned_feat = torch.stack(aligned_feat, dim=1) # (b, t, c, h, w)
+
+ # TSA fusion
+ return self.fusion(aligned_feat)
diff --git a/StableSR/basicsr/archs/basicvsrpp_arch.py b/StableSR/basicsr/archs/basicvsrpp_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a9952e4b441de0030d665a3db141774184f332f
--- /dev/null
+++ b/StableSR/basicsr/archs/basicvsrpp_arch.py
@@ -0,0 +1,417 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+import warnings
+
+from basicsr.archs.arch_util import flow_warp
+from basicsr.archs.basicvsr_arch import ConvResidualBlocks
+from basicsr.archs.spynet_arch import SpyNet
+from basicsr.ops.dcn import ModulatedDeformConvPack
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+@ARCH_REGISTRY.register()
+class BasicVSRPlusPlus(nn.Module):
+ """BasicVSR++ network structure.
+
+ Support either x4 upsampling or same size output. Since DCN is used in this
+ model, it can only be used with CUDA enabled. If CUDA is not enabled,
+ feature alignment will be skipped. Besides, we adopt the official DCN
+    implementation, and the torch version needs to be higher than 1.9.
+
+ ``Paper: BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment``
+
+ Args:
+ mid_channels (int, optional): Channel number of the intermediate
+ features. Default: 64.
+ num_blocks (int, optional): The number of residual blocks in each
+ propagation branch. Default: 7.
+ max_residue_magnitude (int): The maximum magnitude of the offset
+ residue (Eq. 6 in paper). Default: 10.
+ is_low_res_input (bool, optional): Whether the input is low-resolution
+ or not. If False, the output resolution is equal to the input
+ resolution. Default: True.
+ spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
+ cpu_cache_length (int, optional): When the length of sequence is larger
+ than this value, the intermediate features are sent to CPU. This
+ saves GPU memory, but slows down the inference speed. You can
+ increase this number if you have a GPU with large memory.
+ Default: 100.
+ """
+
+ def __init__(self,
+ mid_channels=64,
+ num_blocks=7,
+ max_residue_magnitude=10,
+ is_low_res_input=True,
+ spynet_path=None,
+ cpu_cache_length=100):
+
+ super().__init__()
+ self.mid_channels = mid_channels
+ self.is_low_res_input = is_low_res_input
+ self.cpu_cache_length = cpu_cache_length
+
+ # optical flow
+ self.spynet = SpyNet(spynet_path)
+
+ # feature extraction module
+ if is_low_res_input:
+ self.feat_extract = ConvResidualBlocks(3, mid_channels, 5)
+ else:
+ self.feat_extract = nn.Sequential(
+ nn.Conv2d(3, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ nn.Conv2d(mid_channels, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ ConvResidualBlocks(mid_channels, mid_channels, 5))
+
+ # propagation branches
+ self.deform_align = nn.ModuleDict()
+ self.backbone = nn.ModuleDict()
+ modules = ['backward_1', 'forward_1', 'backward_2', 'forward_2']
+ for i, module in enumerate(modules):
+ if torch.cuda.is_available():
+ self.deform_align[module] = SecondOrderDeformableAlignment(
+ 2 * mid_channels,
+ mid_channels,
+ 3,
+ padding=1,
+ deformable_groups=16,
+ max_residue_magnitude=max_residue_magnitude)
+ self.backbone[module] = ConvResidualBlocks((2 + i) * mid_channels, mid_channels, num_blocks)
+
+ # upsampling module
+ self.reconstruction = ConvResidualBlocks(5 * mid_channels, mid_channels, 5)
+
+ self.upconv1 = nn.Conv2d(mid_channels, mid_channels * 4, 3, 1, 1, bias=True)
+ self.upconv2 = nn.Conv2d(mid_channels, 64 * 4, 3, 1, 1, bias=True)
+
+ self.pixel_shuffle = nn.PixelShuffle(2)
+
+ self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
+ self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
+ self.img_upsample = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
+
+ # activation function
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+ # check if the sequence is augmented by flipping
+ self.is_mirror_extended = False
+
+ if len(self.deform_align) > 0:
+ self.is_with_alignment = True
+ else:
+ self.is_with_alignment = False
+ warnings.warn('Deformable alignment module is not added. '
+ 'Probably your CUDA is not configured correctly. DCN can only '
+ 'be used with CUDA enabled. Alignment is skipped now.')
+
+ def check_if_mirror_extended(self, lqs):
+ """Check whether the input is a mirror-extended sequence.
+
+ If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the (t-1-i)-th frame.
+
+ Args:
+ lqs (tensor): Input low quality (LQ) sequence with shape (n, t, c, h, w).
+ """
+
+ if lqs.size(1) % 2 == 0:
+ lqs_1, lqs_2 = torch.chunk(lqs, 2, dim=1)
+ if torch.norm(lqs_1 - lqs_2.flip(1)) == 0:
+ self.is_mirror_extended = True
+
+ def compute_flow(self, lqs):
+ """Compute optical flow using SPyNet for feature alignment.
+
+        Note that if the input is a mirror-extended sequence, 'flows_forward'
+ is not needed, since it is equal to 'flows_backward.flip(1)'.
+
+ Args:
+ lqs (tensor): Input low quality (LQ) sequence with
+ shape (n, t, c, h, w).
+
+ Return:
+ tuple(Tensor): Optical flow. 'flows_forward' corresponds to the flows used for forward-time propagation \
+ (current to previous). 'flows_backward' corresponds to the flows used for backward-time \
+ propagation (current to next).
+ """
+
+ n, t, c, h, w = lqs.size()
+ lqs_1 = lqs[:, :-1, :, :, :].reshape(-1, c, h, w)
+ lqs_2 = lqs[:, 1:, :, :, :].reshape(-1, c, h, w)
+
+ flows_backward = self.spynet(lqs_1, lqs_2).view(n, t - 1, 2, h, w)
+
+ if self.is_mirror_extended: # flows_forward = flows_backward.flip(1)
+ flows_forward = flows_backward.flip(1)
+ else:
+ flows_forward = self.spynet(lqs_2, lqs_1).view(n, t - 1, 2, h, w)
+
+ if self.cpu_cache:
+ flows_backward = flows_backward.cpu()
+ flows_forward = flows_forward.cpu()
+
+ return flows_forward, flows_backward
+
+ def propagate(self, feats, flows, module_name):
+ """Propagate the latent features throughout the sequence.
+
+ Args:
+ feats dict(list[tensor]): Features from previous branches. Each
+ component is a list of tensors with shape (n, c, h, w).
+ flows (tensor): Optical flows with shape (n, t - 1, 2, h, w).
+            module_name (str): The name of the propagation branches. Can either
+ be 'backward_1', 'forward_1', 'backward_2', 'forward_2'.
+
+ Return:
+ dict(list[tensor]): A dictionary containing all the propagated \
+ features. Each key in the dictionary corresponds to a \
+ propagation branch, which is represented by a list of tensors.
+ """
+
+ n, t, _, h, w = flows.size()
+
+ frame_idx = range(0, t + 1)
+ flow_idx = range(-1, t)
+ mapping_idx = list(range(0, len(feats['spatial'])))
+ mapping_idx += mapping_idx[::-1]
+
+ if 'backward' in module_name:
+ frame_idx = frame_idx[::-1]
+ flow_idx = frame_idx
+
+ feat_prop = flows.new_zeros(n, self.mid_channels, h, w)
+ for i, idx in enumerate(frame_idx):
+ feat_current = feats['spatial'][mapping_idx[idx]]
+ if self.cpu_cache:
+ feat_current = feat_current.cuda()
+ feat_prop = feat_prop.cuda()
+ # second-order deformable alignment
+ if i > 0 and self.is_with_alignment:
+ flow_n1 = flows[:, flow_idx[i], :, :, :]
+ if self.cpu_cache:
+ flow_n1 = flow_n1.cuda()
+
+ cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1))
+
+ # initialize second-order features
+ feat_n2 = torch.zeros_like(feat_prop)
+ flow_n2 = torch.zeros_like(flow_n1)
+ cond_n2 = torch.zeros_like(cond_n1)
+
+ if i > 1: # second-order features
+ feat_n2 = feats[module_name][-2]
+ if self.cpu_cache:
+ feat_n2 = feat_n2.cuda()
+
+ flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
+ if self.cpu_cache:
+ flow_n2 = flow_n2.cuda()
+
+ flow_n2 = flow_n1 + flow_warp(flow_n2, flow_n1.permute(0, 2, 3, 1))
+ cond_n2 = flow_warp(feat_n2, flow_n2.permute(0, 2, 3, 1))
+
+ # flow-guided deformable convolution
+ cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1)
+ feat_prop = torch.cat([feat_prop, feat_n2], dim=1)
+ feat_prop = self.deform_align[module_name](feat_prop, cond, flow_n1, flow_n2)
+
+ # concatenate and residual blocks
+ feat = [feat_current] + [feats[k][idx] for k in feats if k not in ['spatial', module_name]] + [feat_prop]
+ if self.cpu_cache:
+ feat = [f.cuda() for f in feat]
+
+ feat = torch.cat(feat, dim=1)
+ feat_prop = feat_prop + self.backbone[module_name](feat)
+ feats[module_name].append(feat_prop)
+
+ if self.cpu_cache:
+ feats[module_name][-1] = feats[module_name][-1].cpu()
+ torch.cuda.empty_cache()
+
+ if 'backward' in module_name:
+ feats[module_name] = feats[module_name][::-1]
+
+ return feats
+
+ def upsample(self, lqs, feats):
+ """Compute the output image given the features.
+
+ Args:
+ lqs (tensor): Input low quality (LQ) sequence with
+ shape (n, t, c, h, w).
+ feats (dict): The features from the propagation branches.
+
+ Returns:
+ Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
+ """
+
+ outputs = []
+ num_outputs = len(feats['spatial'])
+
+ mapping_idx = list(range(0, num_outputs))
+ mapping_idx += mapping_idx[::-1]
+
+ for i in range(0, lqs.size(1)):
+ hr = [feats[k].pop(0) for k in feats if k != 'spatial']
+ hr.insert(0, feats['spatial'][mapping_idx[i]])
+ hr = torch.cat(hr, dim=1)
+ if self.cpu_cache:
+ hr = hr.cuda()
+
+ hr = self.reconstruction(hr)
+ hr = self.lrelu(self.pixel_shuffle(self.upconv1(hr)))
+ hr = self.lrelu(self.pixel_shuffle(self.upconv2(hr)))
+ hr = self.lrelu(self.conv_hr(hr))
+ hr = self.conv_last(hr)
+ if self.is_low_res_input:
+ hr += self.img_upsample(lqs[:, i, :, :, :])
+ else:
+ hr += lqs[:, i, :, :, :]
+
+ if self.cpu_cache:
+ hr = hr.cpu()
+ torch.cuda.empty_cache()
+
+ outputs.append(hr)
+
+ return torch.stack(outputs, dim=1)
+
+ def forward(self, lqs):
+ """Forward function for BasicVSR++.
+
+ Args:
+ lqs (tensor): Input low quality (LQ) sequence with
+ shape (n, t, c, h, w).
+
+ Returns:
+ Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
+ """
+
+ n, t, c, h, w = lqs.size()
+
+ # whether to cache the features in CPU
+ self.cpu_cache = True if t > self.cpu_cache_length else False
+
+ if self.is_low_res_input:
+ lqs_downsample = lqs.clone()
+ else:
+ lqs_downsample = F.interpolate(
+ lqs.view(-1, c, h, w), scale_factor=0.25, mode='bicubic').view(n, t, c, h // 4, w // 4)
+
+ # check whether the input is an extended sequence
+ self.check_if_mirror_extended(lqs)
+
+ feats = {}
+ # compute spatial features
+ if self.cpu_cache:
+ feats['spatial'] = []
+ for i in range(0, t):
+ feat = self.feat_extract(lqs[:, i, :, :, :]).cpu()
+ feats['spatial'].append(feat)
+ torch.cuda.empty_cache()
+ else:
+ feats_ = self.feat_extract(lqs.view(-1, c, h, w))
+ h, w = feats_.shape[2:]
+ feats_ = feats_.view(n, t, -1, h, w)
+ feats['spatial'] = [feats_[:, i, :, :, :] for i in range(0, t)]
+
+ # compute optical flow using the low-res inputs
+ assert lqs_downsample.size(3) >= 64 and lqs_downsample.size(4) >= 64, (
+ 'The height and width of low-res inputs must be at least 64, '
+ f'but got {h} and {w}.')
+ flows_forward, flows_backward = self.compute_flow(lqs_downsample)
+
+        # feature propagation
+ for iter_ in [1, 2]:
+ for direction in ['backward', 'forward']:
+ module = f'{direction}_{iter_}'
+
+ feats[module] = []
+
+ if direction == 'backward':
+ flows = flows_backward
+ elif flows_forward is not None:
+ flows = flows_forward
+ else:
+ flows = flows_backward.flip(1)
+
+ feats = self.propagate(feats, flows, module)
+ if self.cpu_cache:
+ del flows
+ torch.cuda.empty_cache()
+
+ return self.upsample(lqs, feats)
+
+
+class SecondOrderDeformableAlignment(ModulatedDeformConvPack):
+ """Second-order deformable alignment module.
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
+ stride (int or tuple[int]): Same as nn.Conv2d.
+ padding (int or tuple[int]): Same as nn.Conv2d.
+ dilation (int or tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ bias (bool or str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+ False.
+ max_residue_magnitude (int): The maximum magnitude of the offset
+ residue (Eq. 6 in paper). Default: 10.
+ """
+
+ def __init__(self, *args, **kwargs):
+ self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)
+
+ super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)
+
+ self.conv_offset = nn.Sequential(
+ nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ nn.Conv2d(self.out_channels, 27 * self.deformable_groups, 3, 1, 1),
+ )
+
+ self.init_offset()
+
+ def init_offset(self):
+
+ def _constant_init(module, val, bias=0):
+ if hasattr(module, 'weight') and module.weight is not None:
+ nn.init.constant_(module.weight, val)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+ _constant_init(self.conv_offset[-1], val=0, bias=0)
+
+ def forward(self, x, extra_feat, flow_1, flow_2):
+ extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1)
+ out = self.conv_offset(extra_feat)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+
+ # offset
+ offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1))
+ offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
+ offset_1 = offset_1 + flow_1.flip(1).repeat(1, offset_1.size(1) // 2, 1, 1)
+ offset_2 = offset_2 + flow_2.flip(1).repeat(1, offset_2.size(1) // 2, 1, 1)
+ offset = torch.cat([offset_1, offset_2], dim=1)
+
+ # mask
+ mask = torch.sigmoid(mask)
+
+ return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
+ self.dilation, mask)
+
+
+# if __name__ == '__main__':
+# spynet_path = 'experiments/pretrained_models/flownet/spynet_sintel_final-3d2a1287.pth'
+# model = BasicVSRPlusPlus(spynet_path=spynet_path).cuda()
+# input = torch.rand(1, 2, 3, 64, 64).cuda()
+# output = model(input)
+# print('===================')
+# print(output.shape)
diff --git a/StableSR/basicsr/archs/degradat_arch.py b/StableSR/basicsr/archs/degradat_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce09ad666a90f175fb6268435073b314df543813
--- /dev/null
+++ b/StableSR/basicsr/archs/degradat_arch.py
@@ -0,0 +1,90 @@
+from torch import nn as nn
+
+from basicsr.archs.arch_util import ResidualBlockNoBN, default_init_weights
+from basicsr.utils.registry import ARCH_REGISTRY
+
+@ARCH_REGISTRY.register()
+class DEResNet(nn.Module):
+ """Degradation Estimator with ResNetNoBN arch. v2.1, no vector anymore
+ As shown in paper 'Towards Flexible Blind JPEG Artifacts Removal',
+ resnet arch works for image quality estimation.
+ Args:
+ num_in_ch (int): channel number of inputs. Default: 3.
+ num_degradation (int): num of degradation the DE should estimate. Default: 2(blur+noise).
+ degradation_embed_size (int): embedding size of each degradation vector.
+ degradation_degree_actv (int): activation function for degradation degree scalar. Default: sigmoid.
+ num_feats (list): channel number of each stage.
+ num_blocks (list): residual block of each stage.
+ downscales (list): downscales of each stage.
+ """
+
+ def __init__(self,
+ num_in_ch=3,
+ num_degradation=2,
+ degradation_degree_actv='sigmoid',
+ num_feats=(64, 128, 256, 512),
+ num_blocks=(2, 2, 2, 2),
+ downscales=(2, 2, 2, 1)):
+ super(DEResNet, self).__init__()
+
+        # accept both lists and tuples (the defaults above are tuples)
+        assert isinstance(num_feats, (list, tuple))
+        assert isinstance(num_blocks, (list, tuple))
+        assert isinstance(downscales, (list, tuple))
+ assert len(num_feats) == len(num_blocks) and len(num_feats) == len(downscales)
+
+ num_stage = len(num_feats)
+
+ self.conv_first = nn.ModuleList()
+ for _ in range(num_degradation):
+ self.conv_first.append(nn.Conv2d(num_in_ch, num_feats[0], 3, 1, 1))
+ self.body = nn.ModuleList()
+ for _ in range(num_degradation):
+ body = list()
+ for stage in range(num_stage):
+ for _ in range(num_blocks[stage]):
+ body.append(ResidualBlockNoBN(num_feats[stage]))
+ if downscales[stage] == 1:
+ if stage < num_stage - 1 and num_feats[stage] != num_feats[stage + 1]:
+ body.append(nn.Conv2d(num_feats[stage], num_feats[stage + 1], 3, 1, 1))
+ continue
+ elif downscales[stage] == 2:
+ body.append(nn.Conv2d(num_feats[stage], num_feats[min(stage + 1, num_stage - 1)], 3, 2, 1))
+ else:
+ raise NotImplementedError
+ self.body.append(nn.Sequential(*body))
+
+ # self.body = nn.Sequential(*body)
+
+ self.num_degradation = num_degradation
+ self.fc_degree = nn.ModuleList()
+ if degradation_degree_actv == 'sigmoid':
+ actv = nn.Sigmoid
+ elif degradation_degree_actv == 'tanh':
+ actv = nn.Tanh
+ else:
+ raise NotImplementedError(f'only sigmoid and tanh are supported for degradation_degree_actv, '
+ f'{degradation_degree_actv} is not supported yet.')
+ for _ in range(num_degradation):
+ self.fc_degree.append(
+ nn.Sequential(
+ nn.Linear(num_feats[-1], 512),
+ nn.ReLU(inplace=True),
+ nn.Linear(512, 1),
+ actv(),
+ ))
+
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+
+ default_init_weights([self.conv_first, self.body, self.fc_degree], 0.1)
+
+ def forward(self, x):
+ degrees = []
+ for i in range(self.num_degradation):
+ x_out = self.conv_first[i](x)
+ feat = self.body[i](x_out)
+ feat = self.avg_pool(feat)
+ feat = feat.squeeze(-1).squeeze(-1)
+ # for i in range(self.num_degradation):
+ degrees.append(self.fc_degree[i](feat).squeeze(-1))
+
+ return degrees
diff --git a/StableSR/basicsr/archs/dfdnet_arch.py b/StableSR/basicsr/archs/dfdnet_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..4751434c2f17efbb682d9344951604602d853aaa
--- /dev/null
+++ b/StableSR/basicsr/archs/dfdnet_arch.py
@@ -0,0 +1,169 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils.spectral_norm import spectral_norm
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .dfdnet_util import AttentionBlock, Blur, MSDilationBlock, UpResBlock, adaptive_instance_normalization
+from .vgg_arch import VGGFeatureExtractor
+
+
+class SFTUpBlock(nn.Module):
+ """Spatial feature transform (SFT) with upsampling block.
+
+ Args:
+ in_channel (int): Number of input channels.
+ out_channel (int): Number of output channels.
+ kernel_size (int): Kernel size in convolutions. Default: 3.
+ padding (int): Padding in convolutions. Default: 1.
+ """
+
+ def __init__(self, in_channel, out_channel, kernel_size=3, padding=1):
+ super(SFTUpBlock, self).__init__()
+ self.conv1 = nn.Sequential(
+ Blur(in_channel),
+ spectral_norm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
+ nn.LeakyReLU(0.04, True),
+ # The official code applies LeakyReLU(0.2) twice here; a single slope of 0.2 * 0.2 = 0.04 is equivalent
+ )
+ self.convup = nn.Sequential(
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+ spectral_norm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
+ nn.LeakyReLU(0.2, True),
+ )
+
+ # for SFT scale and shift
+ self.scale_block = nn.Sequential(
+ spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
+ spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)))
+ self.shift_block = nn.Sequential(
+ spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
+ spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)), nn.Sigmoid())
+ # The official code uses a sigmoid for the shift block; the reason is not documented
+
+ def forward(self, x, updated_feat):
+ out = self.conv1(x)
+ # SFT
+ scale = self.scale_block(updated_feat)
+ shift = self.shift_block(updated_feat)
+ out = out * scale + shift
+ # upsample
+ out = self.convup(out)
+ return out
+
+
+@ARCH_REGISTRY.register()
+class DFDNet(nn.Module):
+ """DFDNet: Deep Face Dictionary Network.
+
+ It only processes faces with 512x512 size.
+
+ Args:
+ num_feat (int): Number of feature channels.
+ dict_path (str): Path to the facial component dictionary.
+ """
+
+ def __init__(self, num_feat, dict_path):
+ super().__init__()
+ self.parts = ['left_eye', 'right_eye', 'nose', 'mouth']
+ # part_sizes: [80, 80, 50, 110]
+ channel_sizes = [128, 256, 512, 512]
+ self.feature_sizes = np.array([256, 128, 64, 32])
+ self.vgg_layers = ['relu2_2', 'relu3_4', 'relu4_4', 'conv5_4']
+ self.flag_dict_device = False
+
+ # dict
+ self.dict = torch.load(dict_path)
+
+ # vgg face extractor
+ self.vgg_extractor = VGGFeatureExtractor(
+ layer_name_list=self.vgg_layers,
+ vgg_type='vgg19',
+ use_input_norm=True,
+ range_norm=True,
+ requires_grad=False)
+
+ # attention block for fusing dictionary features and input features
+ self.attn_blocks = nn.ModuleDict()
+ for idx, feat_size in enumerate(self.feature_sizes):
+ for name in self.parts:
+ self.attn_blocks[f'{name}_{feat_size}'] = AttentionBlock(channel_sizes[idx])
+
+ # multi scale dilation block
+ self.multi_scale_dilation = MSDilationBlock(num_feat * 8, dilation=[4, 3, 2, 1])
+
+ # upsampling and reconstruction
+ self.upsample0 = SFTUpBlock(num_feat * 8, num_feat * 8)
+ self.upsample1 = SFTUpBlock(num_feat * 8, num_feat * 4)
+ self.upsample2 = SFTUpBlock(num_feat * 4, num_feat * 2)
+ self.upsample3 = SFTUpBlock(num_feat * 2, num_feat)
+ self.upsample4 = nn.Sequential(
+ spectral_norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1)), nn.LeakyReLU(0.2, True), UpResBlock(num_feat),
+ UpResBlock(num_feat), nn.Conv2d(num_feat, 3, kernel_size=3, stride=1, padding=1), nn.Tanh())
+
+ def swap_feat(self, vgg_feat, updated_feat, dict_feat, location, part_name, f_size):
+ """swap the features from the dictionary."""
+ # get the original vgg features
+ part_feat = vgg_feat[:, :, location[1]:location[3], location[0]:location[2]].clone()
+ # resize original vgg features
+ part_resize_feat = F.interpolate(part_feat, dict_feat.size()[2:4], mode='bilinear', align_corners=False)
+ # use adaptive instance normalization to adjust color and illuminations
+ dict_feat = adaptive_instance_normalization(dict_feat, part_resize_feat)
+ # get similarity scores
+ similarity_score = F.conv2d(part_resize_feat, dict_feat)
+ similarity_score = F.softmax(similarity_score.view(-1), dim=0)
+ # select the most similar features in the dict (after norm)
+ select_idx = torch.argmax(similarity_score)
+ swap_feat = F.interpolate(dict_feat[select_idx:select_idx + 1], part_feat.size()[2:4])
+ # attention
+ attn = self.attn_blocks[f'{part_name}_' + str(f_size)](swap_feat - part_feat)
+ attn_feat = attn * swap_feat
+ # update features
+ updated_feat[:, :, location[1]:location[3], location[0]:location[2]] = attn_feat + part_feat
+ return updated_feat
+
+ def put_dict_to_device(self, x):
+ if self.flag_dict_device is False:
+ for k, v in self.dict.items():
+ for kk, vv in v.items():
+ self.dict[k][kk] = vv.to(x)
+ self.flag_dict_device = True
+
+ def forward(self, x, part_locations):
+ """
+ Now only supports testing with batch size = 1.
+
+ Args:
+ x (Tensor): Input faces with shape (b, c, 512, 512).
+ part_locations (list[Tensor]): Part locations.
+ """
+ self.put_dict_to_device(x)
+ # extract vggface features
+ vgg_features = self.vgg_extractor(x)
+ # update vggface features using the dictionary for each part
+ updated_vgg_features = []
+ batch = 0  # batch index; only testing with batch size = 1 is supported
+ for vgg_layer, f_size in zip(self.vgg_layers, self.feature_sizes):
+ dict_features = self.dict[f'{f_size}']
+ vgg_feat = vgg_features[vgg_layer]
+ updated_feat = vgg_feat.clone()
+
+ # swap features from dictionary
+ for part_idx, part_name in enumerate(self.parts):
+ location = (part_locations[part_idx][batch] // (512 / f_size)).int()
+ updated_feat = self.swap_feat(vgg_feat, updated_feat, dict_features[part_name], location, part_name,
+ f_size)
+
+ updated_vgg_features.append(updated_feat)
+
+ vgg_feat_dilation = self.multi_scale_dilation(vgg_features['conv5_4'])
+ # use updated vgg features to modulate the upsampled features with
+ # SFT (Spatial Feature Transform) scaling and shifting manner.
+ upsampled_feat = self.upsample0(vgg_feat_dilation, updated_vgg_features[3])
+ upsampled_feat = self.upsample1(upsampled_feat, updated_vgg_features[2])
+ upsampled_feat = self.upsample2(upsampled_feat, updated_vgg_features[1])
+ upsampled_feat = self.upsample3(upsampled_feat, updated_vgg_features[0])
+ out = self.upsample4(upsampled_feat)
+
+ return out
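For the SFT upsampling block defined in this file, a short sketch (illustrative only, not part of the patch) of how the dictionary-updated features modulate and upscale the decoder features:

import torch
from basicsr.archs.dfdnet_arch import SFTUpBlock  # path as introduced by this diff

block = SFTUpBlock(in_channel=128, out_channel=64)
x = torch.randn(1, 128, 32, 32)             # decoder features
updated_feat = torch.randn(1, 128, 32, 32)  # dictionary-updated VGG features at the same level
out = block(x, updated_feat)                # scale/shift from updated_feat, then x2 upsampling
print(out.shape)                            # torch.Size([1, 64, 64, 64])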
diff --git a/StableSR/basicsr/archs/dfdnet_util.py b/StableSR/basicsr/archs/dfdnet_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4dc0ff738c76852e830b32fffbe65bffb5ddf50
--- /dev/null
+++ b/StableSR/basicsr/archs/dfdnet_util.py
@@ -0,0 +1,162 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.nn.utils.spectral_norm import spectral_norm
+
+
+class BlurFunctionBackward(Function):
+
+ @staticmethod
+ def forward(ctx, grad_output, kernel, kernel_flip):
+ ctx.save_for_backward(kernel, kernel_flip)
+ grad_input = F.conv2d(grad_output, kernel_flip, padding=1, groups=grad_output.shape[1])
+ return grad_input
+
+ @staticmethod
+ def backward(ctx, gradgrad_output):
+ kernel, _ = ctx.saved_tensors
+ grad_input = F.conv2d(gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1])
+ return grad_input, None, None
+
+
+class BlurFunction(Function):
+
+ @staticmethod
+ def forward(ctx, x, kernel, kernel_flip):
+ ctx.save_for_backward(kernel, kernel_flip)
+ output = F.conv2d(x, kernel, padding=1, groups=x.shape[1])
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ kernel, kernel_flip = ctx.saved_tensors
+ grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip)
+ return grad_input, None, None
+
+
+blur = BlurFunction.apply
+
+
+class Blur(nn.Module):
+
+ def __init__(self, channel):
+ super().__init__()
+ kernel = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)
+ kernel = kernel.view(1, 1, 3, 3)
+ kernel = kernel / kernel.sum()
+ kernel_flip = torch.flip(kernel, [2, 3])
+
+ self.kernel = kernel.repeat(channel, 1, 1, 1)
+ self.kernel_flip = kernel_flip.repeat(channel, 1, 1, 1)
+
+ def forward(self, x):
+ return blur(x, self.kernel.type_as(x), self.kernel_flip.type_as(x))
+
+
+def calc_mean_std(feat, eps=1e-5):
+ """Calculate mean and std for adaptive_instance_normalization.
+
+ Args:
+ feat (Tensor): 4D tensor.
+ eps (float): A small value added to the variance to avoid
+ divide-by-zero. Default: 1e-5.
+ """
+ size = feat.size()
+ assert len(size) == 4, 'The input feature should be 4D tensor.'
+ n, c = size[:2]
+ feat_var = feat.view(n, c, -1).var(dim=2) + eps
+ feat_std = feat_var.sqrt().view(n, c, 1, 1)
+ feat_mean = feat.view(n, c, -1).mean(dim=2).view(n, c, 1, 1)
+ return feat_mean, feat_std
+
+
+def adaptive_instance_normalization(content_feat, style_feat):
+ """Adaptive instance normalization.
+
+ Adjust the reference features to have color and illumination similar
+ to those of the degraded features.
+
+ Args:
+ content_feat (Tensor): The reference feature.
+ style_feat (Tensor): The degraded features.
+ """
+ size = content_feat.size()
+ style_mean, style_std = calc_mean_std(style_feat)
+ content_mean, content_std = calc_mean_std(content_feat)
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+
+def AttentionBlock(in_channel):
+ return nn.Sequential(
+ spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
+ spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)))
+
+
+def conv_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True):
+ """Conv block used in MSDilationBlock."""
+
+ return nn.Sequential(
+ spectral_norm(
+ nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ dilation=dilation,
+ padding=((kernel_size - 1) // 2) * dilation,
+ bias=bias)),
+ nn.LeakyReLU(0.2),
+ spectral_norm(
+ nn.Conv2d(
+ out_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ dilation=dilation,
+ padding=((kernel_size - 1) // 2) * dilation,
+ bias=bias)),
+ )
+
+
+class MSDilationBlock(nn.Module):
+ """Multi-scale dilation block."""
+
+ def __init__(self, in_channels, kernel_size=3, dilation=(1, 1, 1, 1), bias=True):
+ super(MSDilationBlock, self).__init__()
+
+ self.conv_blocks = nn.ModuleList()
+ for i in range(4):
+ self.conv_blocks.append(conv_block(in_channels, in_channels, kernel_size, dilation=dilation[i], bias=bias))
+ self.conv_fusion = spectral_norm(
+ nn.Conv2d(
+ in_channels * 4,
+ in_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=(kernel_size - 1) // 2,
+ bias=bias))
+
+ def forward(self, x):
+ out = []
+ for i in range(4):
+ out.append(self.conv_blocks[i](x))
+ out = torch.cat(out, 1)
+ out = self.conv_fusion(out) + x
+ return out
+
+
+class UpResBlock(nn.Module):
+
+ def __init__(self, in_channel):
+ super(UpResBlock, self).__init__()
+ self.body = nn.Sequential(
+ nn.Conv2d(in_channel, in_channel, 3, 1, 1),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(in_channel, in_channel, 3, 1, 1),
+ )
+
+ def forward(self, x):
+ out = x + self.body(x)
+ return out
diff --git a/StableSR/basicsr/archs/discriminator_arch.py b/StableSR/basicsr/archs/discriminator_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..33f9a8f1b25c2052cd3ba801534861a425752e69
--- /dev/null
+++ b/StableSR/basicsr/archs/discriminator_arch.py
@@ -0,0 +1,150 @@
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.nn.utils import spectral_norm
+
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+@ARCH_REGISTRY.register()
+class VGGStyleDiscriminator(nn.Module):
+ """VGG style discriminator with input size 128 x 128 or 256 x 256.
+
+ It is used to train SRGAN, ESRGAN, and VideoGAN.
+
+ Args:
+ num_in_ch (int): Channel number of inputs. Default: 3.
+ num_feat (int): Channel number of base intermediate features. Default: 64.
+ """
+
+ def __init__(self, num_in_ch, num_feat, input_size=128):
+ super(VGGStyleDiscriminator, self).__init__()
+ self.input_size = input_size
+ assert self.input_size == 128 or self.input_size == 256, (
+ f'input size must be 128 or 256, but received {input_size}')
+
+ self.conv0_0 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
+ self.conv0_1 = nn.Conv2d(num_feat, num_feat, 4, 2, 1, bias=False)
+ self.bn0_1 = nn.BatchNorm2d(num_feat, affine=True)
+
+ self.conv1_0 = nn.Conv2d(num_feat, num_feat * 2, 3, 1, 1, bias=False)
+ self.bn1_0 = nn.BatchNorm2d(num_feat * 2, affine=True)
+ self.conv1_1 = nn.Conv2d(num_feat * 2, num_feat * 2, 4, 2, 1, bias=False)
+ self.bn1_1 = nn.BatchNorm2d(num_feat * 2, affine=True)
+
+ self.conv2_0 = nn.Conv2d(num_feat * 2, num_feat * 4, 3, 1, 1, bias=False)
+ self.bn2_0 = nn.BatchNorm2d(num_feat * 4, affine=True)
+ self.conv2_1 = nn.Conv2d(num_feat * 4, num_feat * 4, 4, 2, 1, bias=False)
+ self.bn2_1 = nn.BatchNorm2d(num_feat * 4, affine=True)
+
+ self.conv3_0 = nn.Conv2d(num_feat * 4, num_feat * 8, 3, 1, 1, bias=False)
+ self.bn3_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
+ self.conv3_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
+ self.bn3_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
+
+ self.conv4_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
+ self.bn4_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
+ self.conv4_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
+ self.bn4_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
+
+ if self.input_size == 256:
+ self.conv5_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
+ self.bn5_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
+ self.conv5_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
+ self.bn5_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
+
+ self.linear1 = nn.Linear(num_feat * 8 * 4 * 4, 100)
+ self.linear2 = nn.Linear(100, 1)
+
+ # activation function
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ def forward(self, x):
+ assert x.size(2) == self.input_size, (f'Input size must be identical to input_size, but received {x.size()}.')
+
+ feat = self.lrelu(self.conv0_0(x))
+ feat = self.lrelu(self.bn0_1(self.conv0_1(feat))) # output spatial size: /2
+
+ feat = self.lrelu(self.bn1_0(self.conv1_0(feat)))
+ feat = self.lrelu(self.bn1_1(self.conv1_1(feat))) # output spatial size: /4
+
+ feat = self.lrelu(self.bn2_0(self.conv2_0(feat)))
+ feat = self.lrelu(self.bn2_1(self.conv2_1(feat))) # output spatial size: /8
+
+ feat = self.lrelu(self.bn3_0(self.conv3_0(feat)))
+ feat = self.lrelu(self.bn3_1(self.conv3_1(feat))) # output spatial size: /16
+
+ feat = self.lrelu(self.bn4_0(self.conv4_0(feat)))
+ feat = self.lrelu(self.bn4_1(self.conv4_1(feat))) # output spatial size: /32
+
+ if self.input_size == 256:
+ feat = self.lrelu(self.bn5_0(self.conv5_0(feat)))
+ feat = self.lrelu(self.bn5_1(self.conv5_1(feat))) # output spatial size: / 64
+
+ # spatial size: (4, 4)
+ feat = feat.view(feat.size(0), -1)
+ feat = self.lrelu(self.linear1(feat))
+ out = self.linear2(feat)
+ return out
+
+
+@ARCH_REGISTRY.register(suffix='basicsr')
+class UNetDiscriminatorSN(nn.Module):
+ """Defines a U-Net discriminator with spectral normalization (SN)
+
+ It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+ Args:
+ num_in_ch (int): Channel number of inputs. Default: 3.
+ num_feat (int): Channel number of base intermediate features. Default: 64.
+ skip_connection (bool): Whether to use skip connections between the U-Net encoder and decoder. Default: True.
+ """
+
+ def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
+ super(UNetDiscriminatorSN, self).__init__()
+ self.skip_connection = skip_connection
+ norm = spectral_norm
+ # the first convolution
+ self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
+ # downsample
+ self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
+ self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
+ self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
+ # upsample
+ self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
+ self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
+ self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
+ # extra convolutions
+ self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
+ self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
+ self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
+
+ def forward(self, x):
+ # downsample
+ x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
+ x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
+ x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
+ x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
+
+ # upsample
+ x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
+ x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
+
+ if self.skip_connection:
+ x4 = x4 + x2
+ x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
+ x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
+
+ if self.skip_connection:
+ x5 = x5 + x1
+ x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
+ x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
+
+ if self.skip_connection:
+ x6 = x6 + x0
+
+ # extra convolutions
+ out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
+ out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
+ out = self.conv9(out)
+
+ return out
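A minimal forward-pass sketch for the U-Net discriminator (illustrative only); the spatial size should be divisible by 8 so the skip connections line up after the three down/up-sampling steps:

import torch
from basicsr.archs.discriminator_arch import UNetDiscriminatorSN

d = UNetDiscriminatorSN(num_in_ch=3, num_feat=64, skip_connection=True)
x = torch.randn(1, 3, 128, 128)
print(d(x).shape)  # torch.Size([1, 1, 128, 128]) -- a per-pixel realness map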
diff --git a/StableSR/basicsr/archs/duf_arch.py b/StableSR/basicsr/archs/duf_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b3ab7df4d890c9220d74ed8c461ad9d155120a
--- /dev/null
+++ b/StableSR/basicsr/archs/duf_arch.py
@@ -0,0 +1,276 @@
+import numpy as np
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+class DenseBlocksTemporalReduce(nn.Module):
+ """A concatenation of 3 dense blocks with reduction in temporal dimension.
+
+ Note that the output temporal dimension is 6 fewer than the input temporal dimension, since each of the 3 blocks reduces it by 2.
+
+ Args:
+ num_feat (int): Number of channels in the blocks. Default: 64.
+ num_grow_ch (int): Growing factor of the dense blocks. Default: 32
+ adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation.
+ Set to false if you want to train from scratch. Default: False.
+ """
+
+ def __init__(self, num_feat=64, num_grow_ch=32, adapt_official_weights=False):
+ super(DenseBlocksTemporalReduce, self).__init__()
+ if adapt_official_weights:
+ eps = 1e-3
+ momentum = 1e-3
+ else: # pytorch default values
+ eps = 1e-05
+ momentum = 0.1
+
+ self.temporal_reduce1 = nn.Sequential(
+ nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+ nn.Conv3d(num_feat, num_feat, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True),
+ nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+ nn.Conv3d(num_feat, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
+
+ self.temporal_reduce2 = nn.Sequential(
+ nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+ nn.Conv3d(
+ num_feat + num_grow_ch,
+ num_feat + num_grow_ch, (1, 1, 1),
+ stride=(1, 1, 1),
+ padding=(0, 0, 0),
+ bias=True), nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+ nn.Conv3d(num_feat + num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
+
+ self.temporal_reduce3 = nn.Sequential(
+ nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+ nn.Conv3d(
+ num_feat + 2 * num_grow_ch,
+ num_feat + 2 * num_grow_ch, (1, 1, 1),
+ stride=(1, 1, 1),
+ padding=(0, 0, 0),
+ bias=True), nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum),
+ nn.ReLU(inplace=True),
+ nn.Conv3d(
+ num_feat + 2 * num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
+
+ def forward(self, x):
+ """
+ Args:
+ x (Tensor): Input tensor with shape (b, num_feat, t, h, w).
+
+ Returns:
+ Tensor: Output with shape (b, num_feat + num_grow_ch * 3, 1, h, w).
+ """
+ x1 = self.temporal_reduce1(x)
+ x1 = torch.cat((x[:, :, 1:-1, :, :], x1), 1)
+
+ x2 = self.temporal_reduce2(x1)
+ x2 = torch.cat((x1[:, :, 1:-1, :, :], x2), 1)
+
+ x3 = self.temporal_reduce3(x2)
+ x3 = torch.cat((x2[:, :, 1:-1, :, :], x3), 1)
+
+ return x3
+
+
+class DenseBlocks(nn.Module):
+ """ A concatenation of N dense blocks.
+
+ Args:
+ num_feat (int): Number of channels in the blocks. Default: 64.
+ num_grow_ch (int): Growing factor of the dense blocks. Default: 32.
+ num_block (int): Number of dense blocks. The values are:
+ DUF-S (16 layers): 3
+ DUF-M (28 layers): 9
+ DUF-L (52 layers): 21
+ adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation.
+ Set to false if you want to train from scratch. Default: False.
+ """
+
+ def __init__(self, num_block, num_feat=64, num_grow_ch=16, adapt_official_weights=False):
+ super(DenseBlocks, self).__init__()
+ if adapt_official_weights:
+ eps = 1e-3
+ momentum = 1e-3
+ else: # pytorch default values
+ eps = 1e-05
+ momentum = 0.1
+
+ self.dense_blocks = nn.ModuleList()
+ for i in range(0, num_block):
+ self.dense_blocks.append(
+ nn.Sequential(
+ nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+ nn.Conv3d(
+ num_feat + i * num_grow_ch,
+ num_feat + i * num_grow_ch, (1, 1, 1),
+ stride=(1, 1, 1),
+ padding=(0, 0, 0),
+ bias=True), nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum),
+ nn.ReLU(inplace=True),
+ nn.Conv3d(
+ num_feat + i * num_grow_ch,
+ num_grow_ch, (3, 3, 3),
+ stride=(1, 1, 1),
+ padding=(1, 1, 1),
+ bias=True)))
+
+ def forward(self, x):
+ """
+ Args:
+ x (Tensor): Input tensor with shape (b, num_feat, t, h, w).
+
+ Returns:
+ Tensor: Output with shape (b, num_feat + num_block * num_grow_ch, t, h, w).
+ """
+ for i in range(0, len(self.dense_blocks)):
+ y = self.dense_blocks[i](x)
+ x = torch.cat((x, y), 1)
+ return x
+
+
+class DynamicUpsamplingFilter(nn.Module):
+ """Dynamic upsampling filter used in DUF.
+
+ Reference: https://github.com/yhjo09/VSR-DUF
+
+ It only supports inputs with 3 channels and applies the same filters to all 3 channels.
+
+ Args:
+ filter_size (tuple): Filter size of generated filters. The shape is (kh, kw). Default: (5, 5).
+ """
+
+ def __init__(self, filter_size=(5, 5)):
+ super(DynamicUpsamplingFilter, self).__init__()
+ if not isinstance(filter_size, tuple):
+ raise TypeError(f'The type of filter_size must be tuple, but got {type(filter_size)}')
+ if len(filter_size) != 2:
+ raise ValueError(f'The length of filter size must be 2, but got {len(filter_size)}.')
+ # generate a local expansion filter, similar to im2col
+ self.filter_size = filter_size
+ filter_prod = np.prod(filter_size)
+ expansion_filter = torch.eye(int(filter_prod)).view(filter_prod, 1, *filter_size) # (kh*kw, 1, kh, kw)
+ self.expansion_filter = expansion_filter.repeat(3, 1, 1, 1) # repeat for all the 3 channels
+
+ def forward(self, x, filters):
+ """Forward function for DynamicUpsamplingFilter.
+
+ Args:
+ x (Tensor): Input image with 3 channels. The shape is (n, 3, h, w).
+ filters (Tensor): Generated dynamic filters. The shape is (n, filter_prod, upsampling_square, h, w).
+ filter_prod: prod of filter kernel size, e.g., 1*5*5=25.
+ upsampling_square: similar to pixel shuffle, upsampling_square = upsampling * upsampling.
+ e.g., for x4 upsampling, upsampling_square = 4*4 = 16.
+
+ Returns:
+ Tensor: Filtered image with shape (n, 3*upsampling_square, h, w)
+ """
+ n, filter_prod, upsampling_square, h, w = filters.size()
+ kh, kw = self.filter_size
+ expanded_input = F.conv2d(
+ x, self.expansion_filter.to(x), padding=(kh // 2, kw // 2), groups=3) # (n, 3*filter_prod, h, w)
+ expanded_input = expanded_input.view(n, 3, filter_prod, h, w).permute(0, 3, 4, 1,
+ 2) # (n, h, w, 3, filter_prod)
+ filters = filters.permute(0, 3, 4, 1, 2)  # (n, h, w, filter_prod, upsampling_square)
+ out = torch.matmul(expanded_input, filters) # (n, h, w, 3, upsampling_square)
+ return out.permute(0, 3, 4, 1, 2).view(n, 3 * upsampling_square, h, w)
+
+
+@ARCH_REGISTRY.register()
+class DUF(nn.Module):
+ """Network architecture for DUF
+
+ ``Paper: Deep Video Super-Resolution Network Using Dynamic Upsampling Filters Without Explicit Motion Compensation``
+
+ Reference: https://github.com/yhjo09/VSR-DUF
+
+ For all the models below, 'adapt_official_weights' is only necessary when
+ loading the weights converted from the official TensorFlow weights.
+ Please set it to False if you are training the model from scratch.
+
+ There are three models with different model size: DUF16Layers, DUF28Layers,
+ and DUF52Layers. This class is the base class for these models.
+
+ Args:
+ scale (int): The upsampling factor. Default: 4.
+ num_layer (int): The number of layers. Default: 52.
+ adapt_official_weights (bool): Whether to adapt the weights
+ translated from the official implementation. Set to false if you
+ want to train from scratch. Default: False.
+ """
+
+ def __init__(self, scale=4, num_layer=52, adapt_official_weights=False):
+ super(DUF, self).__init__()
+ self.scale = scale
+ if adapt_official_weights:
+ eps = 1e-3
+ momentum = 1e-3
+ else: # pytorch default values
+ eps = 1e-05
+ momentum = 0.1
+
+ self.conv3d1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
+ self.dynamic_filter = DynamicUpsamplingFilter((5, 5))
+
+ if num_layer == 16:
+ num_block = 3
+ num_grow_ch = 32
+ elif num_layer == 28:
+ num_block = 9
+ num_grow_ch = 16
+ elif num_layer == 52:
+ num_block = 21
+ num_grow_ch = 16
+ else:
+ raise ValueError(f'Only supported (16, 28, 52) layers, but got {num_layer}.')
+
+ self.dense_block1 = DenseBlocks(
+ num_block=num_block, num_feat=64, num_grow_ch=num_grow_ch,
+ adapt_official_weights=adapt_official_weights) # T = 7
+ self.dense_block2 = DenseBlocksTemporalReduce(
+ 64 + num_grow_ch * num_block, num_grow_ch, adapt_official_weights=adapt_official_weights) # T = 1
+ channels = 64 + num_grow_ch * num_block + num_grow_ch * 3
+ self.bn3d2 = nn.BatchNorm3d(channels, eps=eps, momentum=momentum)
+ self.conv3d2 = nn.Conv3d(channels, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
+
+ self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
+ self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
+
+ self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
+ self.conv3d_f2 = nn.Conv3d(
+ 512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
+
+ def forward(self, x):
+ """
+ Args:
+ x (Tensor): Input with shape (b, 7, c, h, w)
+
+ Returns:
+ Tensor: Output with shape (b, c, h * scale, w * scale)
+ """
+ num_batches, num_imgs, _, h, w = x.size()
+
+ x = x.permute(0, 2, 1, 3, 4) # (b, c, 7, h, w) for Conv3D
+ x_center = x[:, :, num_imgs // 2, :, :]
+
+ x = self.conv3d1(x)
+ x = self.dense_block1(x)
+ x = self.dense_block2(x)
+ x = F.relu(self.bn3d2(x), inplace=True)
+ x = F.relu(self.conv3d2(x), inplace=True)
+
+ # residual image
+ res = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True))
+
+ # filter
+ filter_ = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True))
+ filter_ = F.softmax(filter_.view(num_batches, 25, self.scale**2, h, w), dim=1)
+
+ # dynamic filter
+ out = self.dynamic_filter(x_center, filter_)
+ out += res.squeeze_(2)
+ out = F.pixel_shuffle(out, self.scale)
+
+ return out
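A minimal sketch of running DUF on a 7-frame clip (illustrative only); the dense blocks keep the temporal dimension while the temporal-reduce blocks shrink T from 7 to 1, and the dynamic filters upsample the center frame:

import torch
from basicsr.archs.duf_arch import DUF

net = DUF(scale=4, num_layer=28, adapt_official_weights=False)
net.eval()
with torch.no_grad():
    x = torch.randn(1, 7, 3, 32, 32)  # (b, t=7, c, h, w)
    out = net(x)
print(out.shape)  # torch.Size([1, 3, 128, 128])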
diff --git a/StableSR/basicsr/archs/ecbsr_arch.py b/StableSR/basicsr/archs/ecbsr_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe20e772587d74c67fffb40f3b4731cf4f42268b
--- /dev/null
+++ b/StableSR/basicsr/archs/ecbsr_arch.py
@@ -0,0 +1,275 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+class SeqConv3x3(nn.Module):
+ """The re-parameterizable block used in the ECBSR architecture.
+
+ ``Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices``
+
+ Reference: https://github.com/xindongzhang/ECBSR
+
+ Args:
+ seq_type (str): Sequence type, option: conv1x1-conv3x3 | conv1x1-sobelx | conv1x1-sobely | conv1x1-laplacian.
+ in_channels (int): Channel number of input.
+ out_channels (int): Channel number of output.
+ depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1.
+ """
+
+ def __init__(self, seq_type, in_channels, out_channels, depth_multiplier=1):
+ super(SeqConv3x3, self).__init__()
+ self.seq_type = seq_type
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ if self.seq_type == 'conv1x1-conv3x3':
+ self.mid_planes = int(out_channels * depth_multiplier)
+ conv0 = torch.nn.Conv2d(self.in_channels, self.mid_planes, kernel_size=1, padding=0)
+ self.k0 = conv0.weight
+ self.b0 = conv0.bias
+
+ conv1 = torch.nn.Conv2d(self.mid_planes, self.out_channels, kernel_size=3)
+ self.k1 = conv1.weight
+ self.b1 = conv1.bias
+
+ elif self.seq_type == 'conv1x1-sobelx':
+ conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
+ self.k0 = conv0.weight
+ self.b0 = conv0.bias
+
+ # init scale and bias
+ scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
+ self.scale = nn.Parameter(scale)
+ bias = torch.randn(self.out_channels) * 1e-3
+ bias = torch.reshape(bias, (self.out_channels, ))
+ self.bias = nn.Parameter(bias)
+ # init mask
+ self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
+ for i in range(self.out_channels):
+ self.mask[i, 0, 0, 0] = 1.0
+ self.mask[i, 0, 1, 0] = 2.0
+ self.mask[i, 0, 2, 0] = 1.0
+ self.mask[i, 0, 0, 2] = -1.0
+ self.mask[i, 0, 1, 2] = -2.0
+ self.mask[i, 0, 2, 2] = -1.0
+ self.mask = nn.Parameter(data=self.mask, requires_grad=False)
+
+ elif self.seq_type == 'conv1x1-sobely':
+ conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
+ self.k0 = conv0.weight
+ self.b0 = conv0.bias
+
+ # init scale and bias
+ scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
+ self.scale = nn.Parameter(torch.FloatTensor(scale))
+ bias = torch.randn(self.out_channels) * 1e-3
+ bias = torch.reshape(bias, (self.out_channels, ))
+ self.bias = nn.Parameter(torch.FloatTensor(bias))
+ # init mask
+ self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
+ for i in range(self.out_channels):
+ self.mask[i, 0, 0, 0] = 1.0
+ self.mask[i, 0, 0, 1] = 2.0
+ self.mask[i, 0, 0, 2] = 1.0
+ self.mask[i, 0, 2, 0] = -1.0
+ self.mask[i, 0, 2, 1] = -2.0
+ self.mask[i, 0, 2, 2] = -1.0
+ self.mask = nn.Parameter(data=self.mask, requires_grad=False)
+
+ elif self.seq_type == 'conv1x1-laplacian':
+ conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
+ self.k0 = conv0.weight
+ self.b0 = conv0.bias
+
+ # init scale and bias
+ scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
+ self.scale = nn.Parameter(torch.FloatTensor(scale))
+ bias = torch.randn(self.out_channels) * 1e-3
+ bias = torch.reshape(bias, (self.out_channels, ))
+ self.bias = nn.Parameter(torch.FloatTensor(bias))
+ # init mask
+ self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
+ for i in range(self.out_channels):
+ self.mask[i, 0, 0, 1] = 1.0
+ self.mask[i, 0, 1, 0] = 1.0
+ self.mask[i, 0, 1, 2] = 1.0
+ self.mask[i, 0, 2, 1] = 1.0
+ self.mask[i, 0, 1, 1] = -4.0
+ self.mask = nn.Parameter(data=self.mask, requires_grad=False)
+ else:
+ raise ValueError('The type of seqconv is not supported!')
+
+ def forward(self, x):
+ if self.seq_type == 'conv1x1-conv3x3':
+ # conv-1x1
+ y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
+ # explicitly padding with bias
+ y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
+ b0_pad = self.b0.view(1, -1, 1, 1)
+ y0[:, :, 0:1, :] = b0_pad
+ y0[:, :, -1:, :] = b0_pad
+ y0[:, :, :, 0:1] = b0_pad
+ y0[:, :, :, -1:] = b0_pad
+ # conv-3x3
+ y1 = F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
+ else:
+ y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
+ # explicitly padding with bias
+ y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
+ b0_pad = self.b0.view(1, -1, 1, 1)
+ y0[:, :, 0:1, :] = b0_pad
+ y0[:, :, -1:, :] = b0_pad
+ y0[:, :, :, 0:1] = b0_pad
+ y0[:, :, :, -1:] = b0_pad
+ # conv-3x3
+ y1 = F.conv2d(input=y0, weight=self.scale * self.mask, bias=self.bias, stride=1, groups=self.out_channels)
+ return y1
+
+ def rep_params(self):
+ device = self.k0.get_device()
+ if device < 0:
+ device = None
+
+ if self.seq_type == 'conv1x1-conv3x3':
+ # re-param conv kernel
+ rep_weight = F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3))
+ # re-param conv bias
+ rep_bias = torch.ones(1, self.mid_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
+ rep_bias = F.conv2d(input=rep_bias, weight=self.k1).view(-1, ) + self.b1
+ else:
+ tmp = self.scale * self.mask
+ k1 = torch.zeros((self.out_channels, self.out_channels, 3, 3), device=device)
+ for i in range(self.out_channels):
+ k1[i, i, :, :] = tmp[i, 0, :, :]
+ b1 = self.bias
+ # re-param conv kernel
+ rep_weight = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3))
+ # re-param conv bias
+ rep_bias = torch.ones(1, self.out_channels, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
+ rep_bias = F.conv2d(input=rep_bias, weight=k1).view(-1, ) + b1
+ return rep_weight, rep_bias
+
+
+class ECB(nn.Module):
+ """The ECB block used in the ECBSR architecture.
+
+ Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
+ Ref git repo: https://github.com/xindongzhang/ECBSR
+
+ Args:
+ in_channels (int): Channel number of input.
+ out_channels (int): Channel number of output.
+ depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1.
+ act_type (str): Activation type. Option: prelu | relu | rrelu | softplus | linear. Default: prelu.
+ with_idt (bool): Whether to use identity connection. Default: False.
+ """
+
+ def __init__(self, in_channels, out_channels, depth_multiplier, act_type='prelu', with_idt=False):
+ super(ECB, self).__init__()
+
+ self.depth_multiplier = depth_multiplier
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.act_type = act_type
+
+ if with_idt and (self.in_channels == self.out_channels):
+ self.with_idt = True
+ else:
+ self.with_idt = False
+
+ self.conv3x3 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1)
+ self.conv1x1_3x3 = SeqConv3x3('conv1x1-conv3x3', self.in_channels, self.out_channels, self.depth_multiplier)
+ self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.in_channels, self.out_channels)
+ self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.in_channels, self.out_channels)
+ self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.in_channels, self.out_channels)
+
+ if self.act_type == 'prelu':
+ self.act = nn.PReLU(num_parameters=self.out_channels)
+ elif self.act_type == 'relu':
+ self.act = nn.ReLU(inplace=True)
+ elif self.act_type == 'rrelu':
+ self.act = nn.RReLU(lower=-0.05, upper=0.05)
+ elif self.act_type == 'softplus':
+ self.act = nn.Softplus()
+ elif self.act_type == 'linear':
+ pass
+ else:
+ raise ValueError(f'The activation type {act_type} is not supported!')
+
+ def forward(self, x):
+ if self.training:
+ y = self.conv3x3(x) + self.conv1x1_3x3(x) + self.conv1x1_sbx(x) + self.conv1x1_sby(x) + self.conv1x1_lpl(x)
+ if self.with_idt:
+ y += x
+ else:
+ rep_weight, rep_bias = self.rep_params()
+ y = F.conv2d(input=x, weight=rep_weight, bias=rep_bias, stride=1, padding=1)
+ if self.act_type != 'linear':
+ y = self.act(y)
+ return y
+
+ def rep_params(self):
+ weight0, bias0 = self.conv3x3.weight, self.conv3x3.bias
+ weight1, bias1 = self.conv1x1_3x3.rep_params()
+ weight2, bias2 = self.conv1x1_sbx.rep_params()
+ weight3, bias3 = self.conv1x1_sby.rep_params()
+ weight4, bias4 = self.conv1x1_lpl.rep_params()
+ rep_weight, rep_bias = (weight0 + weight1 + weight2 + weight3 + weight4), (
+ bias0 + bias1 + bias2 + bias3 + bias4)
+
+ if self.with_idt:
+ device = rep_weight.get_device()
+ if device < 0:
+ device = None
+ weight_idt = torch.zeros(self.out_channels, self.out_channels, 3, 3, device=device)
+ for i in range(self.out_channels):
+ weight_idt[i, i, 1, 1] = 1.0
+ bias_idt = 0.0
+ rep_weight, rep_bias = rep_weight + weight_idt, rep_bias + bias_idt
+ return rep_weight, rep_bias
+
+
+@ARCH_REGISTRY.register()
+class ECBSR(nn.Module):
+ """ECBSR architecture.
+
+ Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
+ Ref git repo: https://github.com/xindongzhang/ECBSR
+
+ Args:
+ num_in_ch (int): Channel number of inputs.
+ num_out_ch (int): Channel number of outputs.
+ num_block (int): Block number in the trunk network.
+ num_channel (int): Channel number.
+ with_idt (bool): Whether to use identity connections in the ECB blocks.
+ act_type (str): Activation type.
+ scale (int): Upsampling factor.
+ """
+
+ def __init__(self, num_in_ch, num_out_ch, num_block, num_channel, with_idt, act_type, scale):
+ super(ECBSR, self).__init__()
+ self.num_in_ch = num_in_ch
+ self.scale = scale
+
+ backbone = []
+ backbone += [ECB(num_in_ch, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
+ for _ in range(num_block):
+ backbone += [ECB(num_channel, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
+ backbone += [
+ ECB(num_channel, num_out_ch * scale * scale, depth_multiplier=2.0, act_type='linear', with_idt=with_idt)
+ ]
+
+ self.backbone = nn.Sequential(*backbone)
+ self.upsampler = nn.PixelShuffle(scale)
+
+ def forward(self, x):
+ if self.num_in_ch > 1:
+ shortcut = torch.repeat_interleave(x, self.scale * self.scale, dim=1)
+ else:
+ shortcut = x  # broadcasting repeats the single input channel across the output channels (scale * scale times)
+ y = self.backbone(x) + shortcut
+ y = self.upsampler(y)
+ return y
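Because every branch in ECB is linear up to the activation, the re-parameterized 3x3 convolution used at inference should reproduce the multi-branch training-time forward; a quick equivalence check (illustrative only, not part of the patch):

import torch
from basicsr.archs.ecbsr_arch import ECB

block = ECB(16, 16, depth_multiplier=2.0, act_type='linear', with_idt=True)
x = torch.randn(1, 16, 24, 24)
block.train()
y_train = block(x)          # multi-branch forward
block.eval()
with torch.no_grad():
    y_eval = block(x)       # single re-parameterized 3x3 conv
print(torch.allclose(y_train, y_eval, atol=1e-5))  # expected: True (up to numerical tolerance)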
diff --git a/StableSR/basicsr/archs/edsr_arch.py b/StableSR/basicsr/archs/edsr_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..b80566f11fbd4782d68eee8fbf7da686f89dc4e7
--- /dev/null
+++ b/StableSR/basicsr/archs/edsr_arch.py
@@ -0,0 +1,61 @@
+import torch
+from torch import nn as nn
+
+from basicsr.archs.arch_util import ResidualBlockNoBN, Upsample, make_layer
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+@ARCH_REGISTRY.register()
+class EDSR(nn.Module):
+ """EDSR network structure.
+
+ Paper: Enhanced Deep Residual Networks for Single Image Super-Resolution.
+ Ref git repo: https://github.com/thstkdgus35/EDSR-PyTorch
+
+ Args:
+ num_in_ch (int): Channel number of inputs.
+ num_out_ch (int): Channel number of outputs.
+ num_feat (int): Channel number of intermediate features.
+ Default: 64.
+ num_block (int): Block number in the trunk network. Default: 16.
+ upscale (int): Upsampling factor. Supports 2^n and 3.
+ Default: 4.
+ res_scale (float): Used to scale the residual in residual block.
+ Default: 1.
+ img_range (float): Image range. Default: 255.
+ rgb_mean (tuple[float]): Image mean in RGB orders.
+ Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
+ """
+
+ def __init__(self,
+ num_in_ch,
+ num_out_ch,
+ num_feat=64,
+ num_block=16,
+ upscale=4,
+ res_scale=1,
+ img_range=255.,
+ rgb_mean=(0.4488, 0.4371, 0.4040)):
+ super(EDSR, self).__init__()
+
+ self.img_range = img_range
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
+
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+ self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat, res_scale=res_scale, pytorch_init=True)
+ self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.upsample = Upsample(upscale, num_feat)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ def forward(self, x):
+ self.mean = self.mean.type_as(x)
+
+ x = (x - self.mean) * self.img_range
+ x = self.conv_first(x)
+ res = self.conv_after_body(self.body(x))
+ res += x
+
+ x = self.conv_last(self.upsample(res))
+ x = x / self.img_range + self.mean
+
+ return x
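A minimal usage sketch for EDSR (illustrative only); inputs are expected in [0, 1], which matches the mean/img_range normalization applied in forward:

import torch
from basicsr.archs.edsr_arch import EDSR

net = EDSR(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4)
x = torch.rand(1, 3, 48, 48)
print(net(x).shape)  # torch.Size([1, 3, 192, 192])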
diff --git a/StableSR/basicsr/archs/edvr_arch.py b/StableSR/basicsr/archs/edvr_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0c4f47deb383d4fe6108b97436c9dfb1e541583
--- /dev/null
+++ b/StableSR/basicsr/archs/edvr_arch.py
@@ -0,0 +1,382 @@
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import DCNv2Pack, ResidualBlockNoBN, make_layer
+
+
+class PCDAlignment(nn.Module):
+ """Alignment module using Pyramid, Cascading and Deformable convolution
+ (PCD). It is used in EDVR.
+
+ ``Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks``
+
+ Args:
+ num_feat (int): Channel number of middle features. Default: 64.
+ deformable_groups (int): Deformable groups. Default: 8.
+ """
+
+ def __init__(self, num_feat=64, deformable_groups=8):
+ super(PCDAlignment, self).__init__()
+
+ # Pyramid has three levels:
+ # L3: level 3, 1/4 spatial size
+ # L2: level 2, 1/2 spatial size
+ # L1: level 1, original spatial size
+ self.offset_conv1 = nn.ModuleDict()
+ self.offset_conv2 = nn.ModuleDict()
+ self.offset_conv3 = nn.ModuleDict()
+ self.dcn_pack = nn.ModuleDict()
+ self.feat_conv = nn.ModuleDict()
+
+ # Pyramids
+ for i in range(3, 0, -1):
+ level = f'l{i}'
+ self.offset_conv1[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
+ if i == 3:
+ self.offset_conv2[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ else:
+ self.offset_conv2[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
+ self.offset_conv3[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.dcn_pack[level] = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)
+
+ if i < 3:
+ self.feat_conv[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
+
+ # Cascading dcn
+ self.cas_offset_conv1 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
+ self.cas_offset_conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.cas_dcnpack = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)
+
+ self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+ def forward(self, nbr_feat_l, ref_feat_l):
+ """Align neighboring frame features to the reference frame features.
+
+ Args:
+ nbr_feat_l (list[Tensor]): Neighboring feature list. It
+ contains three pyramid levels (L1, L2, L3),
+ each with shape (b, c, h, w).
+ ref_feat_l (list[Tensor]): Reference feature list. It
+ contains three pyramid levels (L1, L2, L3),
+ each with shape (b, c, h, w).
+
+ Returns:
+ Tensor: Aligned features.
+ """
+ # Pyramids
+ upsampled_offset, upsampled_feat = None, None
+ for i in range(3, 0, -1):
+ level = f'l{i}'
+ offset = torch.cat([nbr_feat_l[i - 1], ref_feat_l[i - 1]], dim=1)
+ offset = self.lrelu(self.offset_conv1[level](offset))
+ if i == 3:
+ offset = self.lrelu(self.offset_conv2[level](offset))
+ else:
+ offset = self.lrelu(self.offset_conv2[level](torch.cat([offset, upsampled_offset], dim=1)))
+ offset = self.lrelu(self.offset_conv3[level](offset))
+
+ feat = self.dcn_pack[level](nbr_feat_l[i - 1], offset)
+ if i < 3:
+ feat = self.feat_conv[level](torch.cat([feat, upsampled_feat], dim=1))
+ if i > 1:
+ feat = self.lrelu(feat)
+
+ if i > 1: # upsample offset and features
+ # x2: when we upsample the offset, we should also enlarge
+ # the magnitude.
+ upsampled_offset = self.upsample(offset) * 2
+ upsampled_feat = self.upsample(feat)
+
+ # Cascading
+ offset = torch.cat([feat, ref_feat_l[0]], dim=1)
+ offset = self.lrelu(self.cas_offset_conv2(self.lrelu(self.cas_offset_conv1(offset))))
+ feat = self.lrelu(self.cas_dcnpack(feat, offset))
+ return feat
+
+
+class TSAFusion(nn.Module):
+ """Temporal Spatial Attention (TSA) fusion module.
+
+ Temporal: Calculate the correlation between center frame and
+ neighboring frames;
+ Spatial: It has 3 pyramid levels, the attention is similar to SFT.
+ (SFT: Recovering realistic texture in image super-resolution by deep
+ spatial feature transform.)
+
+ Args:
+ num_feat (int): Channel number of middle features. Default: 64.
+ num_frame (int): Number of frames. Default: 5.
+ center_frame_idx (int): The index of center frame. Default: 2.
+ """
+
+ def __init__(self, num_feat=64, num_frame=5, center_frame_idx=2):
+ super(TSAFusion, self).__init__()
+ self.center_frame_idx = center_frame_idx
+ # temporal attention (before fusion conv)
+ self.temporal_attn1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.temporal_attn2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.feat_fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)
+
+ # spatial attention (after fusion conv)
+ self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)
+ self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1)
+ self.spatial_attn1 = nn.Conv2d(num_frame * num_feat, num_feat, 1)
+ self.spatial_attn2 = nn.Conv2d(num_feat * 2, num_feat, 1)
+ self.spatial_attn3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.spatial_attn4 = nn.Conv2d(num_feat, num_feat, 1)
+ self.spatial_attn5 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.spatial_attn_l1 = nn.Conv2d(num_feat, num_feat, 1)
+ self.spatial_attn_l2 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
+ self.spatial_attn_l3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.spatial_attn_add1 = nn.Conv2d(num_feat, num_feat, 1)
+ self.spatial_attn_add2 = nn.Conv2d(num_feat, num_feat, 1)
+
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+ self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
+
+ def forward(self, aligned_feat):
+ """
+ Args:
+ aligned_feat (Tensor): Aligned features with shape (b, t, c, h, w).
+
+ Returns:
+ Tensor: Features after TSA with the shape (b, c, h, w).
+ """
+ b, t, c, h, w = aligned_feat.size()
+ # temporal attention
+ embedding_ref = self.temporal_attn1(aligned_feat[:, self.center_frame_idx, :, :, :].clone())
+ embedding = self.temporal_attn2(aligned_feat.view(-1, c, h, w))
+ embedding = embedding.view(b, t, -1, h, w) # (b, t, c, h, w)
+
+ corr_l = [] # correlation list
+ for i in range(t):
+ emb_neighbor = embedding[:, i, :, :, :]
+ corr = torch.sum(emb_neighbor * embedding_ref, 1) # (b, h, w)
+ corr_l.append(corr.unsqueeze(1)) # (b, 1, h, w)
+ corr_prob = torch.sigmoid(torch.cat(corr_l, dim=1)) # (b, t, h, w)
+ corr_prob = corr_prob.unsqueeze(2).expand(b, t, c, h, w)
+ corr_prob = corr_prob.contiguous().view(b, -1, h, w) # (b, t*c, h, w)
+ aligned_feat = aligned_feat.view(b, -1, h, w) * corr_prob
+
+ # fusion
+ feat = self.lrelu(self.feat_fusion(aligned_feat))
+
+ # spatial attention
+ attn = self.lrelu(self.spatial_attn1(aligned_feat))
+ attn_max = self.max_pool(attn)
+ attn_avg = self.avg_pool(attn)
+ attn = self.lrelu(self.spatial_attn2(torch.cat([attn_max, attn_avg], dim=1)))
+ # pyramid levels
+ attn_level = self.lrelu(self.spatial_attn_l1(attn))
+ attn_max = self.max_pool(attn_level)
+ attn_avg = self.avg_pool(attn_level)
+ attn_level = self.lrelu(self.spatial_attn_l2(torch.cat([attn_max, attn_avg], dim=1)))
+ attn_level = self.lrelu(self.spatial_attn_l3(attn_level))
+ attn_level = self.upsample(attn_level)
+
+ attn = self.lrelu(self.spatial_attn3(attn)) + attn_level
+ attn = self.lrelu(self.spatial_attn4(attn))
+ attn = self.upsample(attn)
+ attn = self.spatial_attn5(attn)
+ attn_add = self.spatial_attn_add2(self.lrelu(self.spatial_attn_add1(attn)))
+ attn = torch.sigmoid(attn)
+
+ # after initialization, attn is around 0.5, so multiplying by 2 brings (attn * 2) close to 1
+ feat = feat * attn * 2 + attn_add
+ return feat
+
+
+class PredeblurModule(nn.Module):
+ """Pre-dublur module.
+
+ Args:
+ num_in_ch (int): Channel number of input image. Default: 3.
+ num_feat (int): Channel number of intermediate features. Default: 64.
+ hr_in (bool): Whether the input has high resolution. Default: False.
+ """
+
+ def __init__(self, num_in_ch=3, num_feat=64, hr_in=False):
+ super(PredeblurModule, self).__init__()
+ self.hr_in = hr_in
+
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+ if self.hr_in:
+ # downsample x4 by stride conv
+ self.stride_conv_hr1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+ self.stride_conv_hr2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+
+ # generate feature pyramid
+ self.stride_conv_l2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+ self.stride_conv_l3 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+
+ self.resblock_l3 = ResidualBlockNoBN(num_feat=num_feat)
+ self.resblock_l2_1 = ResidualBlockNoBN(num_feat=num_feat)
+ self.resblock_l2_2 = ResidualBlockNoBN(num_feat=num_feat)
+ self.resblock_l1 = nn.ModuleList([ResidualBlockNoBN(num_feat=num_feat) for i in range(5)])
+
+ self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+ def forward(self, x):
+ feat_l1 = self.lrelu(self.conv_first(x))
+ if self.hr_in:
+ feat_l1 = self.lrelu(self.stride_conv_hr1(feat_l1))
+ feat_l1 = self.lrelu(self.stride_conv_hr2(feat_l1))
+
+ # generate feature pyramid
+ feat_l2 = self.lrelu(self.stride_conv_l2(feat_l1))
+ feat_l3 = self.lrelu(self.stride_conv_l3(feat_l2))
+
+ feat_l3 = self.upsample(self.resblock_l3(feat_l3))
+ feat_l2 = self.resblock_l2_1(feat_l2) + feat_l3
+ feat_l2 = self.upsample(self.resblock_l2_2(feat_l2))
+
+ for i in range(2):
+ feat_l1 = self.resblock_l1[i](feat_l1)
+ feat_l1 = feat_l1 + feat_l2
+ for i in range(2, 5):
+ feat_l1 = self.resblock_l1[i](feat_l1)
+ return feat_l1
+
+
+@ARCH_REGISTRY.register()
+class EDVR(nn.Module):
+ """EDVR network structure for video super-resolution.
+
+ Now only support X4 upsampling factor.
+
+ ``Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks``
+
+ Args:
+ num_in_ch (int): Channel number of input image. Default: 3.
+ num_out_ch (int): Channel number of output image. Default: 3.
+ num_feat (int): Channel number of intermediate features. Default: 64.
+ num_frame (int): Number of input frames. Default: 5.
+ deformable_groups (int): Deformable groups. Default: 8.
+ num_extract_block (int): Number of blocks for feature extraction.
+ Default: 5.
+ num_reconstruct_block (int): Number of blocks for reconstruction.
+ Default: 10.
+ center_frame_idx (int): The index of center frame. Frame counting from
+ 0. Default: Middle of input frames.
+ hr_in (bool): Whether the input has high resolution. Default: False.
+ with_predeblur (bool): Whether to use the predeblur module.
+ Default: False.
+ with_tsa (bool): Whether to use the TSA module. Default: True.
+ """
+
+ def __init__(self,
+ num_in_ch=3,
+ num_out_ch=3,
+ num_feat=64,
+ num_frame=5,
+ deformable_groups=8,
+ num_extract_block=5,
+ num_reconstruct_block=10,
+ center_frame_idx=None,
+ hr_in=False,
+ with_predeblur=False,
+ with_tsa=True):
+ super(EDVR, self).__init__()
+ if center_frame_idx is None:
+ self.center_frame_idx = num_frame // 2
+ else:
+ self.center_frame_idx = center_frame_idx
+ self.hr_in = hr_in
+ self.with_predeblur = with_predeblur
+ self.with_tsa = with_tsa
+
+ # extract features for each frame
+ if self.with_predeblur:
+ self.predeblur = PredeblurModule(num_feat=num_feat, hr_in=self.hr_in)
+ self.conv_1x1 = nn.Conv2d(num_feat, num_feat, 1, 1)
+ else:
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+
+ # extract pyramid features
+ self.feature_extraction = make_layer(ResidualBlockNoBN, num_extract_block, num_feat=num_feat)
+ self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+ self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
+ self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+
+ # pcd and tsa module
+ self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=deformable_groups)
+ if self.with_tsa:
+ self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_frame, center_frame_idx=self.center_frame_idx)
+ else:
+ self.fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)
+
+ # reconstruction
+ self.reconstruction = make_layer(ResidualBlockNoBN, num_reconstruct_block, num_feat=num_feat)
+ # upsample
+ self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
+ self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1)
+ self.pixel_shuffle = nn.PixelShuffle(2)
+ self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
+ self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
+
+ # activation function
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+ def forward(self, x):
+ b, t, c, h, w = x.size()
+ if self.hr_in:
+ assert h % 16 == 0 and w % 16 == 0, ('The height and width must be multiple of 16.')
+ else:
+ assert h % 4 == 0 and w % 4 == 0, ('The height and width must be multiple of 4.')
+
+ x_center = x[:, self.center_frame_idx, :, :, :].contiguous()
+
+ # extract features for each frame
+ # L1
+ if self.with_predeblur:
+ feat_l1 = self.conv_1x1(self.predeblur(x.view(-1, c, h, w)))
+ if self.hr_in:
+ h, w = h // 4, w // 4
+ else:
+ feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))
+
+ feat_l1 = self.feature_extraction(feat_l1)
+ # L2
+ feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
+ feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
+ # L3
+ feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
+ feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))
+
+ feat_l1 = feat_l1.view(b, t, -1, h, w)
+ feat_l2 = feat_l2.view(b, t, -1, h // 2, w // 2)
+ feat_l3 = feat_l3.view(b, t, -1, h // 4, w // 4)
+
+ # PCD alignment
+ ref_feat_l = [ # reference feature list
+ feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
+ feat_l3[:, self.center_frame_idx, :, :, :].clone()
+ ]
+ aligned_feat = []
+ for i in range(t):
+ nbr_feat_l = [ # neighboring feature list
+ feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
+ ]
+ aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
+ aligned_feat = torch.stack(aligned_feat, dim=1) # (b, t, c, h, w)
+
+ if not self.with_tsa:
+ aligned_feat = aligned_feat.view(b, -1, h, w)
+ feat = self.fusion(aligned_feat)
+
+ out = self.reconstruction(feat)
+ out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
+ out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
+ out = self.lrelu(self.conv_hr(out))
+ out = self.conv_last(out)
+ if self.hr_in:
+ base = x_center
+ else:
+ base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False)
+ out += base
+ return out
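A shape-level sketch for EDVR (illustrative only, not part of the patch). Note that PCD alignment relies on DCNv2Pack, so this only runs where basicsr's deformable-convolution op is available; h and w must be multiples of 4 when hr_in=False:

import torch
from basicsr.archs.edvr_arch import EDVR

net = EDVR(num_in_ch=3, num_out_ch=3, num_feat=64, num_frame=5, with_tsa=True)
x = torch.randn(1, 5, 3, 64, 64)  # (b, t=5, c, h, w)
print(net(x).shape)               # torch.Size([1, 3, 256, 256]) -- fixed x4 upsampling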
diff --git a/StableSR/basicsr/archs/hifacegan_arch.py b/StableSR/basicsr/archs/hifacegan_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..098e3ed4306eb19ae9da705c0af580a6f74c6cb9
--- /dev/null
+++ b/StableSR/basicsr/archs/hifacegan_arch.py
@@ -0,0 +1,260 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .hifacegan_util import BaseNetwork, LIPEncoder, SPADEResnetBlock, get_nonspade_norm_layer
+
+
+class SPADEGenerator(BaseNetwork):
+ """Generator with SPADEResBlock"""
+
+ def __init__(self,
+ num_in_ch=3,
+ num_feat=64,
+ use_vae=False,
+ z_dim=256,
+ crop_size=512,
+ norm_g='spectralspadesyncbatch3x3',
+ is_train=True,
+ init_train_phase=3): # progressive training disabled
+ super().__init__()
+ self.nf = num_feat
+ self.input_nc = num_in_ch
+ self.is_train = is_train
+ self.train_phase = init_train_phase
+
+ self.scale_ratio = 5 # hardcoded now
+ self.sw = crop_size // (2**self.scale_ratio)
+ self.sh = self.sw # 20210519: By default use square image, aspect_ratio = 1.0
+
+ if use_vae:
+ # In case of VAE, we will sample from random z vector
+ self.fc = nn.Linear(z_dim, 16 * self.nf * self.sw * self.sh)
+ else:
+ # Otherwise, we make the network deterministic by starting with
+ # downsampled segmentation map instead of random z
+ self.fc = nn.Conv2d(num_in_ch, 16 * self.nf, 3, padding=1)
+
+ self.head_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
+
+ self.g_middle_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
+ self.g_middle_1 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
+
+ self.ups = nn.ModuleList([
+ SPADEResnetBlock(16 * self.nf, 8 * self.nf, norm_g),
+ SPADEResnetBlock(8 * self.nf, 4 * self.nf, norm_g),
+ SPADEResnetBlock(4 * self.nf, 2 * self.nf, norm_g),
+ SPADEResnetBlock(2 * self.nf, 1 * self.nf, norm_g)
+ ])
+
+ self.to_rgbs = nn.ModuleList([
+ nn.Conv2d(8 * self.nf, 3, 3, padding=1),
+ nn.Conv2d(4 * self.nf, 3, 3, padding=1),
+ nn.Conv2d(2 * self.nf, 3, 3, padding=1),
+ nn.Conv2d(1 * self.nf, 3, 3, padding=1)
+ ])
+
+ self.up = nn.Upsample(scale_factor=2)
+
+ def encode(self, input_tensor):
+ """
+ Encode input_tensor into feature maps, can be overridden in derived classes
+ Default: nearest downsampling of 2**5 = 32 times
+ """
+ h, w = input_tensor.size()[-2:]
+ sh, sw = h // 2**self.scale_ratio, w // 2**self.scale_ratio
+ x = F.interpolate(input_tensor, size=(sh, sw))
+ return self.fc(x)
+
+ def forward(self, x):
+ # In the original SPADE, seg is a segmentation map, but here we use x instead.
+ seg = x
+
+ x = self.encode(x)
+ x = self.head_0(x, seg)
+
+ x = self.up(x)
+ x = self.g_middle_0(x, seg)
+ x = self.g_middle_1(x, seg)
+
+ if self.is_train:
+ phase = self.train_phase + 1
+ else:
+ phase = len(self.to_rgbs)
+
+ for i in range(phase):
+ x = self.up(x)
+ x = self.ups[i](x, seg)
+
+ x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
+ x = torch.tanh(x)
+
+ return x
+
+ def mixed_guidance_forward(self, input_x, seg=None, n=0, mode='progressive'):
+ """
+ A helper class for subspace visualization. Input and seg are different images.
+ For the first n levels (including encoder) we use input, for the rest we use seg.
+
+ If mode = 'progressive', the output's like: AAABBB
+ If mode = 'one_plug', the output's like: AAABAA
+ If mode = 'one_ablate', the output's like: BBBABB
+ """
+
+ if seg is None:
+ return self.forward(input_x)
+
+ if self.is_train:
+ phase = self.train_phase + 1
+ else:
+ phase = len(self.to_rgbs)
+
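+        # there are (4 + phase) guidance slots in total: the encoder input, head_0,
+        # g_middle_0, g_middle_1, and one per ups[i]; each slot gets input_x or seg
+        # depending on the chosen mode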
+ if mode == 'progressive':
+ n = max(min(n, 4 + phase), 0)
+ guide_list = [input_x] * n + [seg] * (4 + phase - n)
+ elif mode == 'one_plug':
+ n = max(min(n, 4 + phase - 1), 0)
+ guide_list = [seg] * (4 + phase)
+ guide_list[n] = input_x
+ elif mode == 'one_ablate':
+ if n > 3 + phase:
+ return self.forward(input_x)
+ guide_list = [input_x] * (4 + phase)
+ guide_list[n] = seg
+
+ x = self.encode(guide_list[0])
+ x = self.head_0(x, guide_list[1])
+
+ x = self.up(x)
+ x = self.g_middle_0(x, guide_list[2])
+ x = self.g_middle_1(x, guide_list[3])
+
+ for i in range(phase):
+ x = self.up(x)
+ x = self.ups[i](x, guide_list[4 + i])
+
+ x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
+ x = torch.tanh(x)
+
+ return x
+
+
+@ARCH_REGISTRY.register()
+class HiFaceGAN(SPADEGenerator):
+ """
+ HiFaceGAN: SPADEGenerator with a learnable feature encoder
+ Current encoder design: LIPEncoder
+ """
+
+ def __init__(self,
+ num_in_ch=3,
+ num_feat=64,
+ use_vae=False,
+ z_dim=256,
+ crop_size=512,
+ norm_g='spectralspadesyncbatch3x3',
+ is_train=True,
+ init_train_phase=3):
+ super().__init__(num_in_ch, num_feat, use_vae, z_dim, crop_size, norm_g, is_train, init_train_phase)
+ self.lip_encoder = LIPEncoder(num_in_ch, num_feat, self.sw, self.sh, self.scale_ratio)
+
+ def encode(self, input_tensor):
+ return self.lip_encoder(input_tensor)
+
+
+@ARCH_REGISTRY.register()
+class HiFaceGANDiscriminator(BaseNetwork):
+ """
+ Inspired by pix2pixHD multiscale discriminator.
+
+ Args:
+ num_in_ch (int): Channel number of inputs. Default: 3.
+ num_out_ch (int): Channel number of outputs. Default: 3.
+        conditional_d (bool): Whether to use a conditional discriminator.
+            Default: True.
+        num_d (int): Number of multiscale discriminators. Default: 2.
+ n_layers_d (int): Number of downsample layers in each D. Default: 4.
+ num_feat (int): Channel number of base intermediate features.
+ Default: 64.
+ norm_d (str): String to determine normalization layers in D.
+ Choices: [spectral][instance/batch/syncbatch]
+ Default: 'spectralinstance'.
+ keep_features (bool): Keep intermediate features for matching loss, etc.
+ Default: True.
+ """
+
+ def __init__(self,
+ num_in_ch=3,
+ num_out_ch=3,
+ conditional_d=True,
+ num_d=2,
+ n_layers_d=4,
+ num_feat=64,
+ norm_d='spectralinstance',
+ keep_features=True):
+ super().__init__()
+ self.num_d = num_d
+
+ input_nc = num_in_ch
+ if conditional_d:
+ input_nc += num_out_ch
+
+ for i in range(num_d):
+ subnet_d = NLayerDiscriminator(input_nc, n_layers_d, num_feat, norm_d, keep_features)
+ self.add_module(f'discriminator_{i}', subnet_d)
+
+ def downsample(self, x):
+ return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
+
+    # Returns a list of lists of discriminator outputs.
+    # The final result has num_d entries (one per scale), each holding the per-layer outputs of one discriminator.
+ def forward(self, x):
+ result = []
+ for _, _net_d in self.named_children():
+ out = _net_d(x)
+ result.append(out)
+ x = self.downsample(x)
+
+ return result
+
+
+class NLayerDiscriminator(BaseNetwork):
+ """Defines the PatchGAN discriminator with the specified arguments."""
+
+ def __init__(self, input_nc, n_layers_d, num_feat, norm_d, keep_features):
+ super().__init__()
+ kw = 4
+ padw = int(np.ceil((kw - 1.0) / 2))
+ nf = num_feat
+ self.keep_features = keep_features
+
+ norm_layer = get_nonspade_norm_layer(norm_d)
+ sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, False)]]
+
+ for n in range(1, n_layers_d):
+ nf_prev = nf
+ nf = min(nf * 2, 512)
+ stride = 1 if n == n_layers_d - 1 else 2
+ sequence += [[
+ norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=stride, padding=padw)),
+ nn.LeakyReLU(0.2, False)
+ ]]
+
+ sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
+
+ # We divide the layers into groups to extract intermediate layer outputs
+ for n in range(len(sequence)):
+ self.add_module('model' + str(n), nn.Sequential(*sequence[n]))
+
+ def forward(self, x):
+ results = [x]
+ for submodel in self.children():
+ intermediate_output = submodel(results[-1])
+ results.append(intermediate_output)
+
+ if self.keep_features:
+ return results[1:]
+ else:
+ return results[-1]
diff --git a/StableSR/basicsr/archs/hifacegan_util.py b/StableSR/basicsr/archs/hifacegan_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..35cbef3f532fcc6aab0fa57ab316a546d3a17bd5
--- /dev/null
+++ b/StableSR/basicsr/archs/hifacegan_util.py
@@ -0,0 +1,255 @@
+import re
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import init
+# Warning: spectral norm could be buggy
+# under eval mode and multi-GPU inference
+# A workaround is sticking to single-GPU inference and train mode
+from torch.nn.utils import spectral_norm
+
+
+class SPADE(nn.Module):
+
+ def __init__(self, config_text, norm_nc, label_nc):
+ super().__init__()
+
+ assert config_text.startswith('spade')
+ parsed = re.search('spade(\\D+)(\\d)x\\d', config_text)
+ param_free_norm_type = str(parsed.group(1))
+ ks = int(parsed.group(2))
+
+ if param_free_norm_type == 'instance':
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc)
+ elif param_free_norm_type == 'syncbatch':
+ print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc)
+ elif param_free_norm_type == 'batch':
+ self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
+ else:
+ raise ValueError(f'{param_free_norm_type} is not a recognized param-free norm type in SPADE')
+
+ # The dimension of the intermediate embedding space. Yes, hardcoded.
+ nhidden = 128 if norm_nc > 128 else norm_nc
+
+ pw = ks // 2
+ self.mlp_shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
+ self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
+ self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
+
+ def forward(self, x, segmap):
+
+ # Part 1. generate parameter-free normalized activations
+ normalized = self.param_free_norm(x)
+
+ # Part 2. produce scaling and bias conditioned on semantic map
+ segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
+ actv = self.mlp_shared(segmap)
+ gamma = self.mlp_gamma(actv)
+ beta = self.mlp_beta(actv)
+
+ # apply scale and bias
+ out = normalized * gamma + beta
+
+ return out
+
+
+class SPADEResnetBlock(nn.Module):
+ """
+ ResNet block that uses SPADE. It differs from the ResNet block of pix2pixHD in that
+ it takes in the segmentation map as input, learns the skip connection if necessary,
+ and applies normalization first and then convolution.
+ This architecture seemed like a standard architecture for unconditional or
+ class-conditional GAN architecture using residual block.
+ The code was inspired from https://github.com/LMescheder/GAN_stability.
+ """
+
+ def __init__(self, fin, fout, norm_g='spectralspadesyncbatch3x3', semantic_nc=3):
+ super().__init__()
+ # Attributes
+ self.learned_shortcut = (fin != fout)
+ fmiddle = min(fin, fout)
+
+ # create conv layers
+ self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
+ self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
+ if self.learned_shortcut:
+ self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
+
+ # apply spectral norm if specified
+ if 'spectral' in norm_g:
+ self.conv_0 = spectral_norm(self.conv_0)
+ self.conv_1 = spectral_norm(self.conv_1)
+ if self.learned_shortcut:
+ self.conv_s = spectral_norm(self.conv_s)
+
+ # define normalization layers
+ spade_config_str = norm_g.replace('spectral', '')
+ self.norm_0 = SPADE(spade_config_str, fin, semantic_nc)
+ self.norm_1 = SPADE(spade_config_str, fmiddle, semantic_nc)
+ if self.learned_shortcut:
+ self.norm_s = SPADE(spade_config_str, fin, semantic_nc)
+
+ # note the resnet block with SPADE also takes in |seg|,
+ # the semantic segmentation map as input
+ def forward(self, x, seg):
+ x_s = self.shortcut(x, seg)
+ dx = self.conv_0(self.act(self.norm_0(x, seg)))
+ dx = self.conv_1(self.act(self.norm_1(dx, seg)))
+ out = x_s + dx
+ return out
+
+ def shortcut(self, x, seg):
+ if self.learned_shortcut:
+ x_s = self.conv_s(self.norm_s(x, seg))
+ else:
+ x_s = x
+ return x_s
+
+ def act(self, x):
+ return F.leaky_relu(x, 2e-1)
+
+
+class BaseNetwork(nn.Module):
+ """ A basis for hifacegan archs with custom initialization """
+
+ def init_weights(self, init_type='normal', gain=0.02):
+
+ def init_func(m):
+ classname = m.__class__.__name__
+ if classname.find('BatchNorm2d') != -1:
+ if hasattr(m, 'weight') and m.weight is not None:
+ init.normal_(m.weight.data, 1.0, gain)
+ if hasattr(m, 'bias') and m.bias is not None:
+ init.constant_(m.bias.data, 0.0)
+ elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
+ if init_type == 'normal':
+ init.normal_(m.weight.data, 0.0, gain)
+ elif init_type == 'xavier':
+ init.xavier_normal_(m.weight.data, gain=gain)
+ elif init_type == 'xavier_uniform':
+ init.xavier_uniform_(m.weight.data, gain=1.0)
+ elif init_type == 'kaiming':
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+ elif init_type == 'orthogonal':
+ init.orthogonal_(m.weight.data, gain=gain)
+ elif init_type == 'none': # uses pytorch's default init method
+ m.reset_parameters()
+ else:
+ raise NotImplementedError(f'initialization method [{init_type}] is not implemented')
+ if hasattr(m, 'bias') and m.bias is not None:
+ init.constant_(m.bias.data, 0.0)
+
+ self.apply(init_func)
+
+ # propagate to children
+ for m in self.children():
+ if hasattr(m, 'init_weights'):
+ m.init_weights(init_type, gain)
+
+ def forward(self, x):
+ pass
+
+
+def lip2d(x, logit, kernel=3, stride=2, padding=1):
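+    # Local Importance-based Pooling: a weighted average pooling where each location's
+    # weight is exp(logit), i.e. out = avg_pool(x * w) / avg_pool(w) with w = exp(logit)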
+ weight = logit.exp()
+ return F.avg_pool2d(x * weight, kernel, stride, padding) / F.avg_pool2d(weight, kernel, stride, padding)
+
+
+class SoftGate(nn.Module):
+ COEFF = 12.0
+
+ def forward(self, x):
+ return torch.sigmoid(x).mul(self.COEFF)
+
+
+class SimplifiedLIP(nn.Module):
+
+ def __init__(self, channels):
+ super(SimplifiedLIP, self).__init__()
+ self.logit = nn.Sequential(
+ nn.Conv2d(channels, channels, 3, padding=1, bias=False), nn.InstanceNorm2d(channels, affine=True),
+ SoftGate())
+
+ def init_layer(self):
+ self.logit[0].weight.data.fill_(0.0)
+
+ def forward(self, x):
+ frac = lip2d(x, self.logit(x))
+ return frac
+
+
+class LIPEncoder(BaseNetwork):
+ """Local Importance-based Pooling (Ziteng Gao et.al.,ICCV 2019)"""
+
+ def __init__(self, input_nc, ngf, sw, sh, n_2xdown, norm_layer=nn.InstanceNorm2d):
+ super().__init__()
+ self.sw = sw
+ self.sh = sh
+ self.max_ratio = 16
+ # 20200310: Several Convolution (stride 1) + LIP blocks, 4 fold
+ kw = 3
+ pw = (kw - 1) // 2
+
+ model = [
+ nn.Conv2d(input_nc, ngf, kw, stride=1, padding=pw, bias=False),
+ norm_layer(ngf),
+ nn.ReLU(),
+ ]
+ cur_ratio = 1
+ for i in range(n_2xdown):
+ next_ratio = min(cur_ratio * 2, self.max_ratio)
+ model += [
+ SimplifiedLIP(ngf * cur_ratio),
+ nn.Conv2d(ngf * cur_ratio, ngf * next_ratio, kw, stride=1, padding=pw),
+ norm_layer(ngf * next_ratio),
+ ]
+ cur_ratio = next_ratio
+ if i < n_2xdown - 1:
+ model += [nn.ReLU(inplace=True)]
+
+ self.model = nn.Sequential(*model)
+
+ def forward(self, x):
+ return self.model(x)
+
+
+def get_nonspade_norm_layer(norm_type='instance'):
+    # helper function to get the number of output channels of the previous layer
+ def get_out_channel(layer):
+ if hasattr(layer, 'out_channels'):
+ return getattr(layer, 'out_channels')
+ return layer.weight.size(0)
+
+ # this function will be returned
+ def add_norm_layer(layer):
+ nonlocal norm_type
+        if norm_type.startswith('spectral'):
+            layer = spectral_norm(layer)
+            subnorm_type = norm_type[len('spectral'):]
+        else:
+            # plain norm types (e.g. 'instance') have no 'spectral' prefix to strip
+            subnorm_type = norm_type
+
+ if subnorm_type == 'none' or len(subnorm_type) == 0:
+ return layer
+
+ # remove bias in the previous layer, which is meaningless
+ # since it has no effect after normalization
+ if getattr(layer, 'bias', None) is not None:
+ delattr(layer, 'bias')
+ layer.register_parameter('bias', None)
+
+ if subnorm_type == 'batch':
+ norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
+ elif subnorm_type == 'sync_batch':
+ print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
+ # norm_layer = SynchronizedBatchNorm2d(
+ # get_out_channel(layer), affine=True)
+ norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
+ elif subnorm_type == 'instance':
+ norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
+ else:
+ raise ValueError(f'normalization layer {subnorm_type} is not recognized')
+
+ return nn.Sequential(layer, norm_layer)
+
+ print('This is a legacy from nvlabs/SPADE, and will be removed in future versions.')
+ return add_norm_layer
diff --git a/StableSR/basicsr/archs/inception.py b/StableSR/basicsr/archs/inception.py
new file mode 100644
index 0000000000000000000000000000000000000000..de1abef67270dc1aba770943b53577029141f527
--- /dev/null
+++ b/StableSR/basicsr/archs/inception.py
@@ -0,0 +1,307 @@
+# Modified from https://github.com/mseitzer/pytorch-fid/blob/master/pytorch_fid/inception.py # noqa: E501
+# For FID metric
+
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.model_zoo import load_url
+from torchvision import models
+
+# Inception weights ported to Pytorch from
+# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
+LOCAL_FID_WEIGHTS = 'experiments/pretrained_models/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
+
+
+class InceptionV3(nn.Module):
+ """Pretrained InceptionV3 network returning feature maps"""
+
+ # Index of default block of inception to return,
+ # corresponds to output of final average pooling
+ DEFAULT_BLOCK_INDEX = 3
+
+ # Maps feature dimensionality to their output blocks indices
+ BLOCK_INDEX_BY_DIM = {
+ 64: 0, # First max pooling features
+ 192: 1, # Second max pooling features
+ 768: 2, # Pre-aux classifier features
+ 2048: 3 # Final average pooling features
+ }
+
+ def __init__(self,
+                 output_blocks=(DEFAULT_BLOCK_INDEX, ),
+ resize_input=True,
+ normalize_input=True,
+ requires_grad=False,
+ use_fid_inception=True):
+ """Build pretrained InceptionV3.
+
+ Args:
+ output_blocks (list[int]): Indices of blocks to return features of.
+ Possible values are:
+ - 0: corresponds to output of first max pooling
+ - 1: corresponds to output of second max pooling
+ - 2: corresponds to output which is fed to aux classifier
+ - 3: corresponds to output of final average pooling
+ resize_input (bool): If true, bilinearly resizes input to width and
+ height 299 before feeding input to model. As the network
+ without fully connected layers is fully convolutional, it
+ should be able to handle inputs of arbitrary size, so resizing
+ might not be strictly needed. Default: True.
+ normalize_input (bool): If true, scales the input from range (0, 1)
+ to the range the pretrained Inception network expects,
+ namely (-1, 1). Default: True.
+ requires_grad (bool): If true, parameters of the model require
+ gradients. Possibly useful for finetuning the network.
+ Default: False.
+ use_fid_inception (bool): If true, uses the pretrained Inception
+ model used in Tensorflow's FID implementation.
+ If false, uses the pretrained Inception model available in
+ torchvision. The FID Inception model has different weights
+ and a slightly different structure from torchvision's
+ Inception model. If you want to compute FID scores, you are
+ strongly advised to set this parameter to true to get
+ comparable results. Default: True.
+ """
+ super(InceptionV3, self).__init__()
+
+ self.resize_input = resize_input
+ self.normalize_input = normalize_input
+ self.output_blocks = sorted(output_blocks)
+ self.last_needed_block = max(output_blocks)
+
+ assert self.last_needed_block <= 3, ('Last possible output block index is 3')
+
+ self.blocks = nn.ModuleList()
+
+ if use_fid_inception:
+ inception = fid_inception_v3()
+ else:
+ try:
+ inception = models.inception_v3(pretrained=True, init_weights=False)
+ except TypeError:
+ # pytorch < 1.5 does not have init_weights for inception_v3
+ inception = models.inception_v3(pretrained=True)
+
+ # Block 0: input to maxpool1
+ block0 = [
+ inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, inception.Conv2d_2b_3x3,
+ nn.MaxPool2d(kernel_size=3, stride=2)
+ ]
+ self.blocks.append(nn.Sequential(*block0))
+
+ # Block 1: maxpool1 to maxpool2
+ if self.last_needed_block >= 1:
+ block1 = [inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, nn.MaxPool2d(kernel_size=3, stride=2)]
+ self.blocks.append(nn.Sequential(*block1))
+
+ # Block 2: maxpool2 to aux classifier
+ if self.last_needed_block >= 2:
+ block2 = [
+ inception.Mixed_5b,
+ inception.Mixed_5c,
+ inception.Mixed_5d,
+ inception.Mixed_6a,
+ inception.Mixed_6b,
+ inception.Mixed_6c,
+ inception.Mixed_6d,
+ inception.Mixed_6e,
+ ]
+ self.blocks.append(nn.Sequential(*block2))
+
+ # Block 3: aux classifier to final avgpool
+ if self.last_needed_block >= 3:
+ block3 = [
+ inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c,
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
+ ]
+ self.blocks.append(nn.Sequential(*block3))
+
+ for param in self.parameters():
+ param.requires_grad = requires_grad
+
+ def forward(self, x):
+ """Get Inception feature maps.
+
+ Args:
+ x (Tensor): Input tensor of shape (b, 3, h, w).
+ Values are expected to be in range (-1, 1). You can also input
+ (0, 1) with setting normalize_input = True.
+
+ Returns:
+ list[Tensor]: Corresponding to the selected output block, sorted
+ ascending by index.
+ """
+ output = []
+
+ if self.resize_input:
+ x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
+
+ if self.normalize_input:
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
+
+ for idx, block in enumerate(self.blocks):
+ x = block(x)
+ if idx in self.output_blocks:
+ output.append(x)
+
+ if idx == self.last_needed_block:
+ break
+
+ return output
+
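+# A minimal usage sketch for FID feature extraction (illustrative comment only, not
+# part of the upstream API): select the 2048-d final-pooling block and feed images
+# scaled to (0, 1), e.g.
+#     block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
+#     model = InceptionV3(output_blocks=[block_idx]).eval()
+#     feat = model(img)[0]  # -> (b, 2048, 1, 1)
+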
+
+def fid_inception_v3():
+ """Build pretrained Inception model for FID computation.
+
+ The Inception model for FID computation uses a different set of weights
+ and has a slightly different structure than torchvision's Inception.
+
+ This method first constructs torchvision's Inception and then patches the
+ necessary parts that are different in the FID Inception model.
+ """
+ try:
+ inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False, init_weights=False)
+ except TypeError:
+ # pytorch < 1.5 does not have init_weights for inception_v3
+ inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False)
+
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
+ inception.Mixed_7b = FIDInceptionE_1(1280)
+ inception.Mixed_7c = FIDInceptionE_2(2048)
+
+ if os.path.exists(LOCAL_FID_WEIGHTS):
+ state_dict = torch.load(LOCAL_FID_WEIGHTS, map_location=lambda storage, loc: storage)
+ else:
+ state_dict = load_url(FID_WEIGHTS_URL, progress=True)
+
+ inception.load_state_dict(state_dict)
+ return inception
+
+
+class FIDInceptionA(models.inception.InceptionA):
+ """InceptionA block patched for FID computation"""
+
+ def __init__(self, in_channels, pool_features):
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
+
+ def forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch5x5 = self.branch5x5_1(x)
+ branch5x5 = self.branch5x5_2(branch5x5)
+
+ branch3x3dbl = self.branch3x3dbl_1(x)
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+        # Patch: TensorFlow's average pool does not include the padded zeros in
+ # its average calculation
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+class FIDInceptionC(models.inception.InceptionC):
+ """InceptionC block patched for FID computation"""
+
+ def __init__(self, in_channels, channels_7x7):
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
+
+ def forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch7x7 = self.branch7x7_1(x)
+ branch7x7 = self.branch7x7_2(branch7x7)
+ branch7x7 = self.branch7x7_3(branch7x7)
+
+ branch7x7dbl = self.branch7x7dbl_1(x)
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
+
+        # Patch: TensorFlow's average pool does not include the padded zeros in
+ # its average calculation
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+class FIDInceptionE_1(models.inception.InceptionE):
+ """First InceptionE block patched for FID computation"""
+
+ def __init__(self, in_channels):
+ super(FIDInceptionE_1, self).__init__(in_channels)
+
+ def forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch3x3 = self.branch3x3_1(x)
+ branch3x3 = [
+ self.branch3x3_2a(branch3x3),
+ self.branch3x3_2b(branch3x3),
+ ]
+ branch3x3 = torch.cat(branch3x3, 1)
+
+ branch3x3dbl = self.branch3x3dbl_1(x)
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+ branch3x3dbl = [
+ self.branch3x3dbl_3a(branch3x3dbl),
+ self.branch3x3dbl_3b(branch3x3dbl),
+ ]
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+        # Patch: TensorFlow's average pool does not include the padded zeros in
+ # its average calculation
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+class FIDInceptionE_2(models.inception.InceptionE):
+ """Second InceptionE block patched for FID computation"""
+
+ def __init__(self, in_channels):
+ super(FIDInceptionE_2, self).__init__(in_channels)
+
+ def forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch3x3 = self.branch3x3_1(x)
+ branch3x3 = [
+ self.branch3x3_2a(branch3x3),
+ self.branch3x3_2b(branch3x3),
+ ]
+ branch3x3 = torch.cat(branch3x3, 1)
+
+ branch3x3dbl = self.branch3x3dbl_1(x)
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+ branch3x3dbl = [
+ self.branch3x3dbl_3a(branch3x3dbl),
+ self.branch3x3dbl_3b(branch3x3dbl),
+ ]
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+ # Patch: The FID Inception model uses max pooling instead of average
+ # pooling. This is likely an error in this specific Inception
+ # implementation, as other Inception models use average pooling here
+ # (which matches the description in the paper).
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+ return torch.cat(outputs, 1)
diff --git a/StableSR/basicsr/archs/rcan_arch.py b/StableSR/basicsr/archs/rcan_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..48872e6800006d885f56f90dd2f0a2bd16e513d9
--- /dev/null
+++ b/StableSR/basicsr/archs/rcan_arch.py
@@ -0,0 +1,135 @@
+import torch
+from torch import nn as nn
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import Upsample, make_layer
+
+
+class ChannelAttention(nn.Module):
+ """Channel attention used in RCAN.
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
+ """
+
+ def __init__(self, num_feat, squeeze_factor=16):
+ super(ChannelAttention, self).__init__()
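+        # squeeze-and-excitation style gating: global average pooling, a bottleneck of
+        # num_feat // squeeze_factor channels, then a sigmoid that rescales x per channel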
+ self.attention = nn.Sequential(
+ nn.AdaptiveAvgPool2d(1), nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
+ nn.ReLU(inplace=True), nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), nn.Sigmoid())
+
+ def forward(self, x):
+ y = self.attention(x)
+ return x * y
+
+
+class RCAB(nn.Module):
+ """Residual Channel Attention Block (RCAB) used in RCAN.
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
+ res_scale (float): Scale the residual. Default: 1.
+ """
+
+ def __init__(self, num_feat, squeeze_factor=16, res_scale=1):
+ super(RCAB, self).__init__()
+ self.res_scale = res_scale
+
+ self.rcab = nn.Sequential(
+ nn.Conv2d(num_feat, num_feat, 3, 1, 1), nn.ReLU(True), nn.Conv2d(num_feat, num_feat, 3, 1, 1),
+ ChannelAttention(num_feat, squeeze_factor))
+
+ def forward(self, x):
+ res = self.rcab(x) * self.res_scale
+ return res + x
+
+
+class ResidualGroup(nn.Module):
+ """Residual Group of RCAB.
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ num_block (int): Block number in the body network.
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
+ res_scale (float): Scale the residual. Default: 1.
+ """
+
+ def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1):
+ super(ResidualGroup, self).__init__()
+
+ self.residual_group = make_layer(
+ RCAB, num_block, num_feat=num_feat, squeeze_factor=squeeze_factor, res_scale=res_scale)
+ self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+
+ def forward(self, x):
+ res = self.conv(self.residual_group(x))
+ return res + x
+
+
+@ARCH_REGISTRY.register()
+class RCAN(nn.Module):
+ """Residual Channel Attention Networks.
+
+ ``Paper: Image Super-Resolution Using Very Deep Residual Channel Attention Networks``
+
+ Reference: https://github.com/yulunzhang/RCAN
+
+ Args:
+ num_in_ch (int): Channel number of inputs.
+ num_out_ch (int): Channel number of outputs.
+ num_feat (int): Channel number of intermediate features.
+ Default: 64.
+ num_group (int): Number of ResidualGroup. Default: 10.
+ num_block (int): Number of RCAB in ResidualGroup. Default: 16.
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
+ upscale (int): Upsampling factor. Support 2^n and 3.
+ Default: 4.
+ res_scale (float): Used to scale the residual in residual block.
+ Default: 1.
+ img_range (float): Image range. Default: 255.
+ rgb_mean (tuple[float]): Image mean in RGB orders.
+ Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
+ """
+
+ def __init__(self,
+ num_in_ch,
+ num_out_ch,
+ num_feat=64,
+ num_group=10,
+ num_block=16,
+ squeeze_factor=16,
+ upscale=4,
+ res_scale=1,
+ img_range=255.,
+ rgb_mean=(0.4488, 0.4371, 0.4040)):
+ super(RCAN, self).__init__()
+
+ self.img_range = img_range
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
+
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+ self.body = make_layer(
+ ResidualGroup,
+ num_group,
+ num_feat=num_feat,
+ num_block=num_block,
+ squeeze_factor=squeeze_factor,
+ res_scale=res_scale)
+ self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.upsample = Upsample(upscale, num_feat)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ def forward(self, x):
+ self.mean = self.mean.type_as(x)
+
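+        # remove the RGB mean and scale by img_range; the inverse mapping is applied
+        # after reconstruction below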
+ x = (x - self.mean) * self.img_range
+ x = self.conv_first(x)
+ res = self.conv_after_body(self.body(x))
+ res += x
+
+ x = self.conv_last(self.upsample(res))
+ x = x / self.img_range + self.mean
+
+ return x
diff --git a/StableSR/basicsr/archs/ridnet_arch.py b/StableSR/basicsr/archs/ridnet_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..85bb9ae0348e27dd6c797c03f8d9ec43f8b0b829
--- /dev/null
+++ b/StableSR/basicsr/archs/ridnet_arch.py
@@ -0,0 +1,180 @@
+import torch
+import torch.nn as nn
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import ResidualBlockNoBN, make_layer
+
+
+class MeanShift(nn.Conv2d):
+ """ Data normalization with mean and std.
+
+ Args:
+ rgb_range (int): Maximum value of RGB.
+ rgb_mean (list[float]): Mean for RGB channels.
+ rgb_std (list[float]): Std for RGB channels.
+ sign (int): For subtraction, sign is -1, for addition, sign is 1.
+ Default: -1.
+ requires_grad (bool): Whether to update the self.weight and self.bias.
+ Default: True.
+ """
+
+ def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1, requires_grad=True):
+ super(MeanShift, self).__init__(3, 3, kernel_size=1)
+ std = torch.Tensor(rgb_std)
+ self.weight.data = torch.eye(3).view(3, 3, 1, 1)
+ self.weight.data.div_(std.view(3, 1, 1, 1))
+ self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
+        self.bias.data.div_(std)
+        # set requires_grad on the parameters themselves; assigning it on the module has no effect
+        for p in self.parameters():
+            p.requires_grad = requires_grad
+
+
+class EResidualBlockNoBN(nn.Module):
+ """Enhanced Residual block without BN.
+
+ There are three convolution layers in residual branch.
+ """
+
+ def __init__(self, in_channels, out_channels):
+ super(EResidualBlockNoBN, self).__init__()
+
+ self.body = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, 3, 1, 1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(out_channels, out_channels, 1, 1, 0),
+ )
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ out = self.body(x)
+ out = self.relu(out + x)
+ return out
+
+
+class MergeRun(nn.Module):
+ """ Merge-and-run unit.
+
+ This unit contains two branches with different dilated convolutions,
+ followed by a convolution to process the concatenated features.
+
+ Paper: Real Image Denoising with Feature Attention
+ Ref git repo: https://github.com/saeed-anwar/RIDNet
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
+ super(MergeRun, self).__init__()
+
+ self.dilation1 = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True),
+ nn.Conv2d(out_channels, out_channels, kernel_size, stride, 2, 2), nn.ReLU(inplace=True))
+ self.dilation2 = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, kernel_size, stride, 3, 3), nn.ReLU(inplace=True),
+ nn.Conv2d(out_channels, out_channels, kernel_size, stride, 4, 4), nn.ReLU(inplace=True))
+
+ self.aggregation = nn.Sequential(
+ nn.Conv2d(out_channels * 2, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True))
+
+ def forward(self, x):
+ dilation1 = self.dilation1(x)
+ dilation2 = self.dilation2(x)
+ out = torch.cat([dilation1, dilation2], dim=1)
+ out = self.aggregation(out)
+ out = out + x
+ return out
+
+
+class ChannelAttention(nn.Module):
+ """Channel attention.
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+        squeeze_factor (int): Channel squeeze factor. Default: 16.
+ """
+
+ def __init__(self, mid_channels, squeeze_factor=16):
+ super(ChannelAttention, self).__init__()
+ self.attention = nn.Sequential(
+ nn.AdaptiveAvgPool2d(1), nn.Conv2d(mid_channels, mid_channels // squeeze_factor, 1, padding=0),
+ nn.ReLU(inplace=True), nn.Conv2d(mid_channels // squeeze_factor, mid_channels, 1, padding=0), nn.Sigmoid())
+
+ def forward(self, x):
+ y = self.attention(x)
+ return x * y
+
+
+class EAM(nn.Module):
+ """Enhancement attention modules (EAM) in RIDNet.
+
+ This module contains a merge-and-run unit, a residual block,
+ an enhanced residual block and a feature attention unit.
+
+ Attributes:
+ merge: The merge-and-run unit.
+ block1: The residual block.
+ block2: The enhanced residual block.
+ ca: The feature/channel attention unit.
+ """
+
+ def __init__(self, in_channels, mid_channels, out_channels):
+ super(EAM, self).__init__()
+
+ self.merge = MergeRun(in_channels, mid_channels)
+ self.block1 = ResidualBlockNoBN(mid_channels)
+ self.block2 = EResidualBlockNoBN(mid_channels, out_channels)
+ self.ca = ChannelAttention(out_channels)
+ # The residual block in the paper contains a relu after addition.
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ out = self.merge(x)
+ out = self.relu(self.block1(out))
+ out = self.block2(out)
+ out = self.ca(out)
+ return out
+
+
+@ARCH_REGISTRY.register()
+class RIDNet(nn.Module):
+ """RIDNet: Real Image Denoising with Feature Attention.
+
+ Ref git repo: https://github.com/saeed-anwar/RIDNet
+
+ Args:
+ in_channels (int): Channel number of inputs.
+ mid_channels (int): Channel number of EAM modules.
+ Default: 64.
+ out_channels (int): Channel number of outputs.
+ num_block (int): Number of EAM. Default: 4.
+ img_range (float): Image range. Default: 255.
+ rgb_mean (tuple[float]): Image mean in RGB orders.
+ Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
+ """
+
+ def __init__(self,
+ in_channels,
+ mid_channels,
+ out_channels,
+ num_block=4,
+ img_range=255.,
+ rgb_mean=(0.4488, 0.4371, 0.4040),
+ rgb_std=(1.0, 1.0, 1.0)):
+ super(RIDNet, self).__init__()
+
+ self.sub_mean = MeanShift(img_range, rgb_mean, rgb_std)
+ self.add_mean = MeanShift(img_range, rgb_mean, rgb_std, 1)
+
+ self.head = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)
+ self.body = make_layer(
+ EAM, num_block, in_channels=mid_channels, mid_channels=mid_channels, out_channels=mid_channels)
+ self.tail = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ res = self.sub_mean(x)
+ res = self.tail(self.body(self.relu(self.head(res))))
+ res = self.add_mean(res)
+
+ out = x + res
+ return out
diff --git a/StableSR/basicsr/archs/rrdbnet_arch.py b/StableSR/basicsr/archs/rrdbnet_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..63d07080c2ec1305090c59b7bfbbda2b003b18e4
--- /dev/null
+++ b/StableSR/basicsr/archs/rrdbnet_arch.py
@@ -0,0 +1,119 @@
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import default_init_weights, make_layer, pixel_unshuffle
+
+
+class ResidualDenseBlock(nn.Module):
+ """Residual Dense Block.
+
+ Used in RRDB block in ESRGAN.
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ num_grow_ch (int): Channels for each growth.
+ """
+
+ def __init__(self, num_feat=64, num_grow_ch=32):
+ super(ResidualDenseBlock, self).__init__()
+ self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
+ self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
+
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ # initialization
+ default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+ def forward(self, x):
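+        # dense connections: each conv takes the concatenation of the input and all
+        # previously produced feature maps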
+ x1 = self.lrelu(self.conv1(x))
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+ # Empirically, we use 0.2 to scale the residual for better performance
+ return x5 * 0.2 + x
+
+
+class RRDB(nn.Module):
+ """Residual in Residual Dense Block.
+
+ Used in RRDB-Net in ESRGAN.
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ num_grow_ch (int): Channels for each growth.
+ """
+
+ def __init__(self, num_feat, num_grow_ch=32):
+ super(RRDB, self).__init__()
+ self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
+ self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
+ self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
+
+ def forward(self, x):
+ out = self.rdb1(x)
+ out = self.rdb2(out)
+ out = self.rdb3(out)
+ # Empirically, we use 0.2 to scale the residual for better performance
+ return out * 0.2 + x
+
+
+@ARCH_REGISTRY.register()
+class RRDBNet(nn.Module):
+ """Networks consisting of Residual in Residual Dense Block, which is used
+ in ESRGAN.
+
+ ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
+
+ We extend ESRGAN for scale x2 and scale x1.
+ Note: This is one option for scale 1, scale 2 in RRDBNet.
+    We first employ pixel-unshuffle (the inverse of pixel-shuffle) to reduce the spatial size
+    and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
+
+ Args:
+ num_in_ch (int): Channel number of inputs.
+        num_out_ch (int): Channel number of outputs.
+        scale (int): Upsampling scale factor. Supported: 1, 2 and 4. Default: 4.
+ num_feat (int): Channel number of intermediate features.
+ Default: 64
+        num_block (int): Block number in the trunk network. Default: 23.
+ num_grow_ch (int): Channels for each growth. Default: 32.
+ """
+
+ def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
+ super(RRDBNet, self).__init__()
+ self.scale = scale
+ if scale == 2:
+ num_in_ch = num_in_ch * 4
+ elif scale == 1:
+ num_in_ch = num_in_ch * 16
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+ self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
+ self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ # upsample
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ def forward(self, x):
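+        # for x2 / x1 scales, pixel-unshuffle trades spatial resolution for channels so
+        # that the same x4 upsampling trunk can be reused (see the class docstring)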
+ if self.scale == 2:
+ feat = pixel_unshuffle(x, scale=2)
+ elif self.scale == 1:
+ feat = pixel_unshuffle(x, scale=4)
+ else:
+ feat = x
+ feat = self.conv_first(feat)
+ body_feat = self.conv_body(self.body(feat))
+ feat = feat + body_feat
+ # upsample
+ feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
+ feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
+ out = self.conv_last(self.lrelu(self.conv_hr(feat)))
+ return out
diff --git a/StableSR/basicsr/archs/spynet_arch.py b/StableSR/basicsr/archs/spynet_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c7af133daef0496b79a57517e1942d06f2d0061
--- /dev/null
+++ b/StableSR/basicsr/archs/spynet_arch.py
@@ -0,0 +1,96 @@
+import math
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import flow_warp
+
+
+class BasicModule(nn.Module):
+ """Basic Module for SpyNet.
+ """
+
+ def __init__(self):
+ super(BasicModule, self).__init__()
+
+ self.basic_module = nn.Sequential(
+ nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
+ nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
+ nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
+ nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
+ nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))
+
+ def forward(self, tensor_input):
+ return self.basic_module(tensor_input)
+
+
+@ARCH_REGISTRY.register()
+class SpyNet(nn.Module):
+ """SpyNet architecture.
+
+ Args:
+ load_path (str): path for pretrained SpyNet. Default: None.
+ """
+
+ def __init__(self, load_path=None):
+ super(SpyNet, self).__init__()
+ self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)])
+ if load_path:
+ self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
+
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ def preprocess(self, tensor_input):
+ tensor_output = (tensor_input - self.mean) / self.std
+ return tensor_output
+
+ def process(self, ref, supp):
+ flow = []
+
+ ref = [self.preprocess(ref)]
+ supp = [self.preprocess(supp)]
+
+ for level in range(5):
+ ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False))
+ supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False))
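+        # ref[0] / supp[0] now hold the coarsest of six pyramid levels; the flow is
+        # estimated coarse-to-fine, each level predicting a residual on the upsampled flow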
+
+ flow = ref[0].new_zeros(
+ [ref[0].size(0), 2,
+ int(math.floor(ref[0].size(2) / 2.0)),
+ int(math.floor(ref[0].size(3) / 2.0))])
+
+ for level in range(len(ref)):
+ upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
+
+ if upsampled_flow.size(2) != ref[level].size(2):
+ upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate')
+ if upsampled_flow.size(3) != ref[level].size(3):
+ upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate')
+
+ flow = self.basic_module[level](torch.cat([
+ ref[level],
+ flow_warp(
+ supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'),
+ upsampled_flow
+ ], 1)) + upsampled_flow
+
+ return flow
+
+ def forward(self, ref, supp):
+ assert ref.size() == supp.size()
+
+ h, w = ref.size(2), ref.size(3)
+ w_floor = math.floor(math.ceil(w / 32.0) * 32.0)
+ h_floor = math.floor(math.ceil(h / 32.0) * 32.0)
+
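+        # resize so that height and width are multiples of 32 (the pyramid downsamples by
+        # 2**5); the estimated flow is rescaled back to the original (h, w) below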
+ ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False)
+ supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False)
+
+ flow = F.interpolate(input=self.process(ref, supp), size=(h, w), mode='bilinear', align_corners=False)
+
+ flow[:, 0, :, :] *= float(w) / float(w_floor)
+ flow[:, 1, :, :] *= float(h) / float(h_floor)
+
+ return flow
diff --git a/StableSR/basicsr/archs/srresnet_arch.py b/StableSR/basicsr/archs/srresnet_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f571557cd7d9ba8791bd6462fccf648c57186d2
--- /dev/null
+++ b/StableSR/basicsr/archs/srresnet_arch.py
@@ -0,0 +1,65 @@
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import ResidualBlockNoBN, default_init_weights, make_layer
+
+
+@ARCH_REGISTRY.register()
+class MSRResNet(nn.Module):
+ """Modified SRResNet.
+
+    A compact version modified from SRResNet in
+ "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network"
+ It uses residual blocks without BN, similar to EDSR.
+ Currently, it supports x2, x3 and x4 upsampling scale factor.
+
+ Args:
+ num_in_ch (int): Channel number of inputs. Default: 3.
+ num_out_ch (int): Channel number of outputs. Default: 3.
+ num_feat (int): Channel number of intermediate features. Default: 64.
+ num_block (int): Block number in the body network. Default: 16.
+ upscale (int): Upsampling factor. Support x2, x3 and x4. Default: 4.
+ """
+
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4):
+ super(MSRResNet, self).__init__()
+ self.upscale = upscale
+
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+ self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat)
+
+ # upsampling
+ if self.upscale in [2, 3]:
+ self.upconv1 = nn.Conv2d(num_feat, num_feat * self.upscale * self.upscale, 3, 1, 1)
+ self.pixel_shuffle = nn.PixelShuffle(self.upscale)
+ elif self.upscale == 4:
+ self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
+ self.upconv2 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
+ self.pixel_shuffle = nn.PixelShuffle(2)
+
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ # activation function
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+
+ # initialization
+ default_init_weights([self.conv_first, self.upconv1, self.conv_hr, self.conv_last], 0.1)
+ if self.upscale == 4:
+ default_init_weights(self.upconv2, 0.1)
+
+ def forward(self, x):
+ feat = self.lrelu(self.conv_first(x))
+ out = self.body(feat)
+
+ if self.upscale == 4:
+ out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
+ out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
+ elif self.upscale in [2, 3]:
+ out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
+
+ out = self.conv_last(self.lrelu(self.conv_hr(out)))
+ base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
+ out += base
+ return out
diff --git a/StableSR/basicsr/archs/srvgg_arch.py b/StableSR/basicsr/archs/srvgg_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8fe5ceb40ed9edd35d81ee17aff86f2e3d9adb4
--- /dev/null
+++ b/StableSR/basicsr/archs/srvgg_arch.py
@@ -0,0 +1,70 @@
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+@ARCH_REGISTRY.register(suffix='basicsr')
+class SRVGGNetCompact(nn.Module):
+ """A compact VGG-style network structure for super-resolution.
+
+    It is a compact network structure that performs upsampling only in the last layer, so no convolution
+    is conducted on the HR feature space.
+
+ Args:
+ num_in_ch (int): Channel number of inputs. Default: 3.
+ num_out_ch (int): Channel number of outputs. Default: 3.
+ num_feat (int): Channel number of intermediate features. Default: 64.
+ num_conv (int): Number of convolution layers in the body network. Default: 16.
+ upscale (int): Upsampling factor. Default: 4.
+ act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
+ """
+
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
+ super(SRVGGNetCompact, self).__init__()
+ self.num_in_ch = num_in_ch
+ self.num_out_ch = num_out_ch
+ self.num_feat = num_feat
+ self.num_conv = num_conv
+ self.upscale = upscale
+ self.act_type = act_type
+
+ self.body = nn.ModuleList()
+ # the first conv
+ self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
+ # the first activation
+ if act_type == 'relu':
+ activation = nn.ReLU(inplace=True)
+ elif act_type == 'prelu':
+ activation = nn.PReLU(num_parameters=num_feat)
+ elif act_type == 'leakyrelu':
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+ self.body.append(activation)
+
+ # the body structure
+ for _ in range(num_conv):
+ self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
+ # activation
+ if act_type == 'relu':
+ activation = nn.ReLU(inplace=True)
+ elif act_type == 'prelu':
+ activation = nn.PReLU(num_parameters=num_feat)
+ elif act_type == 'leakyrelu':
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+ self.body.append(activation)
+
+ # the last conv
+ self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
+ # upsample
+ self.upsampler = nn.PixelShuffle(upscale)
+
+ def forward(self, x):
+ out = x
+ for i in range(0, len(self.body)):
+ out = self.body[i](out)
+
+ out = self.upsampler(out)
+ # add the nearest upsampled image, so that the network learns the residual
+ base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
+ out += base
+ return out
diff --git a/StableSR/basicsr/archs/stylegan2_arch.py b/StableSR/basicsr/archs/stylegan2_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ab37f5a33a2ef21641de35109c16b511a6df163
--- /dev/null
+++ b/StableSR/basicsr/archs/stylegan2_arch.py
@@ -0,0 +1,799 @@
+import math
+import random
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
+from basicsr.ops.upfirdn2d import upfirdn2d
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+class NormStyleCode(nn.Module):
+
+ def forward(self, x):
+ """Normalize the style codes.
+
+ Args:
+ x (Tensor): Style codes with shape (b, c).
+
+ Returns:
+ Tensor: Normalized tensor.
+ """
+ return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
+
+
+def make_resample_kernel(k):
+ """Make resampling kernel for UpFirDn.
+
+ Args:
+ k (list[int]): A list indicating the 1D resample kernel magnitude.
+
+ Returns:
+ Tensor: 2D resampled kernel.
+ """
+ k = torch.tensor(k, dtype=torch.float32)
+ if k.ndim == 1:
+ k = k[None, :] * k[:, None] # to 2D kernel, outer product
+ # normalize
+ k /= k.sum()
+ return k
+
+
+class UpFirDnUpsample(nn.Module):
+ """Upsample, FIR filter, and downsample (upsampole version).
+
+ References:
+ 1. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.upfirdn.html # noqa: E501
+ 2. http://www.ece.northwestern.edu/local-apps/matlabhelp/toolbox/signal/upfirdn.html # noqa: E501
+
+ Args:
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
+ magnitude.
+ factor (int): Upsampling scale factor. Default: 2.
+ """
+
+ def __init__(self, resample_kernel, factor=2):
+ super(UpFirDnUpsample, self).__init__()
+ self.kernel = make_resample_kernel(resample_kernel) * (factor**2)
+ self.factor = factor
+
+ pad = self.kernel.shape[0] - factor
+ self.pad = ((pad + 1) // 2 + factor - 1, pad // 2)
+
+ def forward(self, x):
+ out = upfirdn2d(x, self.kernel.type_as(x), up=self.factor, down=1, pad=self.pad)
+ return out
+
+ def __repr__(self):
+ return (f'{self.__class__.__name__}(factor={self.factor})')
+
+
+class UpFirDnDownsample(nn.Module):
+ """Upsample, FIR filter, and downsample (downsampole version).
+
+ Args:
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
+ magnitude.
+ factor (int): Downsampling scale factor. Default: 2.
+ """
+
+ def __init__(self, resample_kernel, factor=2):
+ super(UpFirDnDownsample, self).__init__()
+ self.kernel = make_resample_kernel(resample_kernel)
+ self.factor = factor
+
+ pad = self.kernel.shape[0] - factor
+ self.pad = ((pad + 1) // 2, pad // 2)
+
+ def forward(self, x):
+ out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=self.factor, pad=self.pad)
+ return out
+
+ def __repr__(self):
+ return (f'{self.__class__.__name__}(factor={self.factor})')
+
+
+class UpFirDnSmooth(nn.Module):
+ """Upsample, FIR filter, and downsample (smooth version).
+
+ Args:
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
+ magnitude.
+ upsample_factor (int): Upsampling scale factor. Default: 1.
+ downsample_factor (int): Downsampling scale factor. Default: 1.
+        kernel_size (int): Kernel size. Default: 1.
+ """
+
+ def __init__(self, resample_kernel, upsample_factor=1, downsample_factor=1, kernel_size=1):
+ super(UpFirDnSmooth, self).__init__()
+ self.upsample_factor = upsample_factor
+ self.downsample_factor = downsample_factor
+ self.kernel = make_resample_kernel(resample_kernel)
+ if upsample_factor > 1:
+ self.kernel = self.kernel * (upsample_factor**2)
+
+ if upsample_factor > 1:
+ pad = (self.kernel.shape[0] - upsample_factor) - (kernel_size - 1)
+ self.pad = ((pad + 1) // 2 + upsample_factor - 1, pad // 2 + 1)
+ elif downsample_factor > 1:
+ pad = (self.kernel.shape[0] - downsample_factor) + (kernel_size - 1)
+ self.pad = ((pad + 1) // 2, pad // 2)
+ else:
+ raise NotImplementedError
+
+ def forward(self, x):
+ out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=1, pad=self.pad)
+ return out
+
+ def __repr__(self):
+ return (f'{self.__class__.__name__}(upsample_factor={self.upsample_factor}'
+ f', downsample_factor={self.downsample_factor})')
+
+
+class EqualLinear(nn.Module):
+ """Equalized Linear as StyleGAN2.
+
+ Args:
+ in_channels (int): Size of each sample.
+ out_channels (int): Size of each output sample.
+ bias (bool): If set to ``False``, the layer will not learn an additive
+ bias. Default: ``True``.
+ bias_init_val (float): Bias initialized value. Default: 0.
+ lr_mul (float): Learning rate multiplier. Default: 1.
+ activation (None | str): The activation after ``linear`` operation.
+ Supported: 'fused_lrelu', None. Default: None.
+ """
+
+ def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None):
+ super(EqualLinear, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.lr_mul = lr_mul
+ self.activation = activation
+ if self.activation not in ['fused_lrelu', None]:
+            raise ValueError(f'Wrong activation value in EqualLinear: {activation}. '
+ "Supported ones are: ['fused_lrelu', None].")
+ self.scale = (1 / math.sqrt(in_channels)) * lr_mul
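+        # equalized learning rate: the stored weight is divided by lr_mul and rescaled by
+        # 1 / sqrt(in_channels) * lr_mul at runtime, as in StyleGAN2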
+
+ self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
+ else:
+ self.register_parameter('bias', None)
+
+ def forward(self, x):
+ if self.bias is None:
+ bias = None
+ else:
+ bias = self.bias * self.lr_mul
+ if self.activation == 'fused_lrelu':
+ out = F.linear(x, self.weight * self.scale)
+ out = fused_leaky_relu(out, bias)
+ else:
+ out = F.linear(x, self.weight * self.scale, bias=bias)
+ return out
+
+ def __repr__(self):
+ return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
+ f'out_channels={self.out_channels}, bias={self.bias is not None})')
+
+
+class ModulatedConv2d(nn.Module):
+ """Modulated Conv2d used in StyleGAN2.
+
+ There is no bias in ModulatedConv2d.
+
+ Args:
+ in_channels (int): Channel number of the input.
+ out_channels (int): Channel number of the output.
+ kernel_size (int): Size of the convolving kernel.
+ num_style_feat (int): Channel number of style features.
+ demodulate (bool): Whether to demodulate in the conv layer.
+ Default: True.
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
+ Default: None.
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
+ magnitude. Default: (1, 3, 3, 1).
+ eps (float): A value added to the denominator for numerical stability.
+ Default: 1e-8.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ num_style_feat,
+ demodulate=True,
+ sample_mode=None,
+ resample_kernel=(1, 3, 3, 1),
+ eps=1e-8):
+ super(ModulatedConv2d, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.demodulate = demodulate
+ self.sample_mode = sample_mode
+ self.eps = eps
+
+ if self.sample_mode == 'upsample':
+ self.smooth = UpFirDnSmooth(
+ resample_kernel, upsample_factor=2, downsample_factor=1, kernel_size=kernel_size)
+ elif self.sample_mode == 'downsample':
+ self.smooth = UpFirDnSmooth(
+ resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size)
+ elif self.sample_mode is None:
+ pass
+ else:
+ raise ValueError(f'Wrong sample mode {self.sample_mode}, '
+ "supported ones are ['upsample', 'downsample', None].")
+
+ self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
+ # modulation inside each modulated conv
+ self.modulation = EqualLinear(
+ num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)
+
+ self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
+ self.padding = kernel_size // 2
+
+ def forward(self, x, style):
+ """Forward function.
+
+ Args:
+ x (Tensor): Tensor with shape (b, c, h, w).
+ style (Tensor): Tensor with shape (b, num_style_feat).
+
+ Returns:
+ Tensor: Modulated tensor after convolution.
+ """
+ b, c, h, w = x.shape # c = c_in
+ # weight modulation
+ style = self.modulation(style).view(b, 1, c, 1, 1)
+ # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
+ weight = self.scale * self.weight * style # (b, c_out, c_in, k, k)
+
+ if self.demodulate:
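+            # weight demodulation: rescale each (sample, out-channel) filter to roughly unit l2-norm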
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
+ weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
+
+ weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
+
+ if self.sample_mode == 'upsample':
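+            # fold the batch into the group dimension so that each sample is convolved
+            # (transposed, stride 2) with its own modulated kernel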
+ x = x.view(1, b * c, h, w)
+ weight = weight.view(b, self.out_channels, c, self.kernel_size, self.kernel_size)
+ weight = weight.transpose(1, 2).reshape(b * c, self.out_channels, self.kernel_size, self.kernel_size)
+ out = F.conv_transpose2d(x, weight, padding=0, stride=2, groups=b)
+ out = out.view(b, self.out_channels, *out.shape[2:4])
+ out = self.smooth(out)
+ elif self.sample_mode == 'downsample':
+ x = self.smooth(x)
+ x = x.view(1, b * c, *x.shape[2:4])
+ out = F.conv2d(x, weight, padding=0, stride=2, groups=b)
+ out = out.view(b, self.out_channels, *out.shape[2:4])
+ else:
+ x = x.view(1, b * c, h, w)
+ # weight: (b*c_out, c_in, k, k), groups=b
+ out = F.conv2d(x, weight, padding=self.padding, groups=b)
+ out = out.view(b, self.out_channels, *out.shape[2:4])
+
+ return out
+
+ def __repr__(self):
+ return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
+ f'out_channels={self.out_channels}, '
+ f'kernel_size={self.kernel_size}, '
+ f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
+
+
+class StyleConv(nn.Module):
+ """Style conv.
+
+ Args:
+ in_channels (int): Channel number of the input.
+ out_channels (int): Channel number of the output.
+ kernel_size (int): Size of the convolving kernel.
+ num_style_feat (int): Channel number of style features.
+ demodulate (bool): Whether demodulate in the conv layer. Default: True.
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
+ Default: None.
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
+ magnitude. Default: (1, 3, 3, 1).
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ num_style_feat,
+ demodulate=True,
+ sample_mode=None,
+ resample_kernel=(1, 3, 3, 1)):
+ super(StyleConv, self).__init__()
+ self.modulated_conv = ModulatedConv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ num_style_feat,
+ demodulate=demodulate,
+ sample_mode=sample_mode,
+ resample_kernel=resample_kernel)
+ self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
+ self.activate = FusedLeakyReLU(out_channels)
+
+ def forward(self, x, style, noise=None):
+ # modulate
+ out = self.modulated_conv(x, style)
+ # noise injection
+ if noise is None:
+ b, _, h, w = out.shape
+ noise = out.new_empty(b, 1, h, w).normal_()
+ out = out + self.weight * noise
+ # activation (with bias)
+ out = self.activate(out)
+ return out
+
+
+class ToRGB(nn.Module):
+ """To RGB from features.
+
+ Args:
+ in_channels (int): Channel number of input.
+ num_style_feat (int): Channel number of style features.
+ upsample (bool): Whether to upsample. Default: True.
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
+ magnitude. Default: (1, 3, 3, 1).
+ """
+
+ def __init__(self, in_channels, num_style_feat, upsample=True, resample_kernel=(1, 3, 3, 1)):
+ super(ToRGB, self).__init__()
+ if upsample:
+ self.upsample = UpFirDnUpsample(resample_kernel, factor=2)
+ else:
+ self.upsample = None
+ self.modulated_conv = ModulatedConv2d(
+ in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
+
+ def forward(self, x, style, skip=None):
+ """Forward function.
+
+ Args:
+ x (Tensor): Feature tensor with shape (b, c, h, w).
+ style (Tensor): Tensor with shape (b, num_style_feat).
+ skip (Tensor): Base/skip tensor. Default: None.
+
+ Returns:
+ Tensor: RGB images.
+ """
+ out = self.modulated_conv(x, style)
+ out = out + self.bias
+ if skip is not None:
+ if self.upsample:
+ skip = self.upsample(skip)
+ out = out + skip
+ return out
+
+
+class ConstantInput(nn.Module):
+ """Constant input.
+
+ Args:
+ num_channel (int): Channel number of constant input.
+ size (int): Spatial size of constant input.
+ """
+
+ def __init__(self, num_channel, size):
+ super(ConstantInput, self).__init__()
+ self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
+
+ def forward(self, batch):
+ out = self.weight.repeat(batch, 1, 1, 1)
+ return out
+
+
+@ARCH_REGISTRY.register()
+class StyleGAN2Generator(nn.Module):
+ """StyleGAN2 Generator.
+
+ Args:
+ out_size (int): The spatial size of outputs.
+ num_style_feat (int): Channel number of style features. Default: 512.
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
+ channel_multiplier (int): Channel multiplier for large networks of
+ StyleGAN2. Default: 2.
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
+ magnitude. An outer product will be applied to extend the 1D resample
+ kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
+ lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
+ narrow (float): Narrow ratio for channels. Default: 1.0.
+ """
+
+ def __init__(self,
+ out_size,
+ num_style_feat=512,
+ num_mlp=8,
+ channel_multiplier=2,
+ resample_kernel=(1, 3, 3, 1),
+ lr_mlp=0.01,
+ narrow=1):
+ super(StyleGAN2Generator, self).__init__()
+ # Style MLP layers
+ self.num_style_feat = num_style_feat
+ style_mlp_layers = [NormStyleCode()]
+ for i in range(num_mlp):
+ style_mlp_layers.append(
+ EqualLinear(
+ num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp,
+ activation='fused_lrelu'))
+ self.style_mlp = nn.Sequential(*style_mlp_layers)
+
+ channels = {
+ '4': int(512 * narrow),
+ '8': int(512 * narrow),
+ '16': int(512 * narrow),
+ '32': int(512 * narrow),
+ '64': int(256 * channel_multiplier * narrow),
+ '128': int(128 * channel_multiplier * narrow),
+ '256': int(64 * channel_multiplier * narrow),
+ '512': int(32 * channel_multiplier * narrow),
+ '1024': int(16 * channel_multiplier * narrow)
+ }
+ self.channels = channels
+
+ self.constant_input = ConstantInput(channels['4'], size=4)
+ self.style_conv1 = StyleConv(
+ channels['4'],
+ channels['4'],
+ kernel_size=3,
+ num_style_feat=num_style_feat,
+ demodulate=True,
+ sample_mode=None,
+ resample_kernel=resample_kernel)
+ self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, resample_kernel=resample_kernel)
+
+ self.log_size = int(math.log(out_size, 2))
+ self.num_layers = (self.log_size - 2) * 2 + 1
+ self.num_latent = self.log_size * 2 - 2
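+ # e.g. out_size=1024 -> log_size=10, num_layers=17 (one conv at 4x4 plus two
+ # per resolution from 8x8 up) and num_latent=18 style codes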
+
+ self.style_convs = nn.ModuleList()
+ self.to_rgbs = nn.ModuleList()
+ self.noises = nn.Module()
+
+ in_channels = channels['4']
+ # noise
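+ # (the first buffer is 4x4 for style_conv1; every later resolution gets two,
+ # hence resolution = 2**((layer_idx + 5) // 2) -> 4, 8, 8, 16, 16, ...)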
+ for layer_idx in range(self.num_layers):
+ resolution = 2**((layer_idx + 5) // 2)
+ shape = [1, 1, resolution, resolution]
+ self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
+ # style convs and to_rgbs
+ for i in range(3, self.log_size + 1):
+ out_channels = channels[f'{2**i}']
+ self.style_convs.append(
+ StyleConv(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ num_style_feat=num_style_feat,
+ demodulate=True,
+ sample_mode='upsample',
+ resample_kernel=resample_kernel,
+ ))
+ self.style_convs.append(
+ StyleConv(
+ out_channels,
+ out_channels,
+ kernel_size=3,
+ num_style_feat=num_style_feat,
+ demodulate=True,
+ sample_mode=None,
+ resample_kernel=resample_kernel))
+ self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True, resample_kernel=resample_kernel))
+ in_channels = out_channels
+
+ def make_noise(self):
+ """Make noise for noise injection."""
+ device = self.constant_input.weight.device
+ noises = [torch.randn(1, 1, 4, 4, device=device)]
+
+ for i in range(3, self.log_size + 1):
+ for _ in range(2):
+ noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
+
+ return noises
+
+ def get_latent(self, x):
+ return self.style_mlp(x)
+
+ def mean_latent(self, num_latent):
+ latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
+ latent = self.style_mlp(latent_in).mean(0, keepdim=True)
+ return latent
+
+ def forward(self,
+ styles,
+ input_is_latent=False,
+ noise=None,
+ randomize_noise=True,
+ truncation=1,
+ truncation_latent=None,
+ inject_index=None,
+ return_latents=False):
+ """Forward function for StyleGAN2Generator.
+
+ Args:
+ styles (list[Tensor]): Sample codes of styles.
+ input_is_latent (bool): Whether input is latent style.
+ Default: False.
+ noise (Tensor | None): Input noise or None. Default: None.
+ randomize_noise (bool): Randomize noise, used when 'noise' is
+ None. Default: True.
+ truncation (float): The truncation ratio for the truncation trick.
+ 1 means no truncation. Default: 1.
+ truncation_latent (Tensor | None): The mean latent used as the
+ truncation center. Required when truncation < 1. Default: None.
+ inject_index (int | None): The injection index for mixing noise.
+ Default: None.
+ return_latents (bool): Whether to return style latents.
+ Default: False.
+ """
+ # style codes -> latents with Style MLP layer
+ if not input_is_latent:
+ styles = [self.style_mlp(s) for s in styles]
+ # noises
+ if noise is None:
+ if randomize_noise:
+ noise = [None] * self.num_layers # for each style conv layer
+ else: # use the stored noise
+ noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
+ # style truncation
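+ # (truncation trick: pull each latent towards truncation_latent, typically
+ # the mean latent from mean_latent(), trading diversity for fidelity)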
+ if truncation < 1:
+ style_truncation = []
+ for style in styles:
+ style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
+ styles = style_truncation
+ # get style latent with injection
+ if len(styles) == 1:
+ inject_index = self.num_latent
+
+ if styles[0].ndim < 3:
+ # repeat latent code for all the layers
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+ else: # used for encoder with different latent code for each layer
+ latent = styles[0]
+ elif len(styles) == 2: # mixing noises
+ if inject_index is None:
+ inject_index = random.randint(1, self.num_latent - 1)
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
+ latent = torch.cat([latent1, latent2], 1)
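+ # style mixing: layers before inject_index use the first latent, the rest
+ # use the second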
+
+ # main generation
+ out = self.constant_input(latent.shape[0])
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
+ skip = self.to_rgb1(out, latent[:, 1])
+
+ i = 1
+ for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
+ noise[2::2], self.to_rgbs):
+ out = conv1(out, latent[:, i], noise=noise1)
+ out = conv2(out, latent[:, i + 1], noise=noise2)
+ skip = to_rgb(out, latent[:, i + 2], skip)
+ i += 2
+
+ image = skip
+
+ if return_latents:
+ return image, latent
+ else:
+ return image, None
+
+
+class ScaledLeakyReLU(nn.Module):
+ """Scaled LeakyReLU.
+
+ Args:
+ negative_slope (float): Negative slope. Default: 0.2.
+ """
+
+ def __init__(self, negative_slope=0.2):
+ super(ScaledLeakyReLU, self).__init__()
+ self.negative_slope = negative_slope
+
+ def forward(self, x):
+ out = F.leaky_relu(x, negative_slope=self.negative_slope)
+ return out * math.sqrt(2)
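+ # the sqrt(2) gain roughly restores unit variance after the LeakyReLU,
+ # matching the scaling used by the fused bias + LeakyReLU op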
+
+
+class EqualConv2d(nn.Module):
+ """Equalized Linear as StyleGAN2.
+
+ Args:
+ in_channels (int): Channel number of the input.
+ out_channels (int): Channel number of the output.
+ kernel_size (int): Size of the convolving kernel.
+ stride (int): Stride of the convolution. Default: 1
+ padding (int): Zero-padding added to both sides of the input.
+ Default: 0.
+ bias (bool): If ``True``, adds a learnable bias to the output.
+ Default: ``True``.
+ bias_init_val (float): Bias initialized value. Default: 0.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0):
+ super(EqualConv2d, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
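+ # equalized learning rate: weights stay ~N(0, 1) and are rescaled by
+ # 1/sqrt(fan_in) at runtime rather than at initialization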
+
+ self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
+ else:
+ self.register_parameter('bias', None)
+
+ def forward(self, x):
+ out = F.conv2d(
+ x,
+ self.weight * self.scale,
+ bias=self.bias,
+ stride=self.stride,
+ padding=self.padding,
+ )
+
+ return out
+
+ def __repr__(self):
+ return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
+ f'out_channels={self.out_channels}, '
+ f'kernel_size={self.kernel_size},'
+ f' stride={self.stride}, padding={self.padding}, '
+ f'bias={self.bias is not None})')
+
+
+class ConvLayer(nn.Sequential):
+ """Conv Layer used in StyleGAN2 Discriminator.
+
+ Args:
+ in_channels (int): Channel number of the input.
+ out_channels (int): Channel number of the output.
+ kernel_size (int): Kernel size.
+ downsample (bool): Whether to downsample by a factor of 2.
+ Default: False.
+ resample_kernel (list[int]): A list indicating the 1D resample
+ kernel magnitude. An outer product will be applied to
+ extend the 1D resample kernel to a 2D resample kernel.
+ Default: (1, 3, 3, 1).
+ bias (bool): Whether to use bias. Default: True.
+ activate (bool): Whether to use activation. Default: True.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ downsample=False,
+ resample_kernel=(1, 3, 3, 1),
+ bias=True,
+ activate=True):
+ layers = []
+ # downsample
+ if downsample:
+ layers.append(
+ UpFirDnSmooth(resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size))
+ stride = 2
+ self.padding = 0
+ else:
+ stride = 1
+ self.padding = kernel_size // 2
+ # conv
+ layers.append(
+ EqualConv2d(
+ in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias
+ and not activate))
+ # activation
+ if activate:
+ if bias:
+ layers.append(FusedLeakyReLU(out_channels))
+ else:
+ layers.append(ScaledLeakyReLU(0.2))
+
+ super(ConvLayer, self).__init__(*layers)
+
+
+class ResBlock(nn.Module):
+ """Residual block used in StyleGAN2 Discriminator.
+
+ Args:
+ in_channels (int): Channel number of the input.
+ out_channels (int): Channel number of the output.
+ resample_kernel (list[int]): A list indicating the 1D resample
+ kernel magnitude. An outer product will be applied to
+ extend the 1D resample kernel to a 2D resample kernel.
+ Default: (1, 3, 3, 1).
+ """
+
+ def __init__(self, in_channels, out_channels, resample_kernel=(1, 3, 3, 1)):
+ super(ResBlock, self).__init__()
+
+ self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
+ self.conv2 = ConvLayer(
+ in_channels, out_channels, 3, downsample=True, resample_kernel=resample_kernel, bias=True, activate=True)
+ self.skip = ConvLayer(
+ in_channels, out_channels, 1, downsample=True, resample_kernel=resample_kernel, bias=False, activate=False)
+
+ def forward(self, x):
+ out = self.conv1(x)
+ out = self.conv2(out)
+ skip = self.skip(x)
+ out = (out + skip) / math.sqrt(2)
+ return out
+
+
+@ARCH_REGISTRY.register()
+class StyleGAN2Discriminator(nn.Module):
+ """StyleGAN2 Discriminator.
+
+ Args:
+ out_size (int): The spatial size of outputs.
+ channel_multiplier (int): Channel multiplier for large networks of
+ StyleGAN2. Default: 2.
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
+ magnitude. An outer product will be applied to extend the 1D resample
+ kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
+ stddev_group (int): For group stddev statistics. Default: 4.
+ narrow (float): Narrow ratio for channels. Default: 1.0.
+ """
+
+ def __init__(self, out_size, channel_multiplier=2, resample_kernel=(1, 3, 3, 1), stddev_group=4, narrow=1):
+ super(StyleGAN2Discriminator, self).__init__()
+
+ channels = {
+ '4': int(512 * narrow),
+ '8': int(512 * narrow),
+ '16': int(512 * narrow),
+ '32': int(512 * narrow),
+ '64': int(256 * channel_multiplier * narrow),
+ '128': int(128 * channel_multiplier * narrow),
+ '256': int(64 * channel_multiplier * narrow),
+ '512': int(32 * channel_multiplier * narrow),
+ '1024': int(16 * channel_multiplier * narrow)
+ }
+
+ log_size = int(math.log(out_size, 2))
+
+ conv_body = [ConvLayer(3, channels[f'{out_size}'], 1, bias=True, activate=True)]
+
+ in_channels = channels[f'{out_size}']
+ for i in range(log_size, 2, -1):
+ out_channels = channels[f'{2**(i - 1)}']
+ conv_body.append(ResBlock(in_channels, out_channels, resample_kernel))
+ in_channels = out_channels
+ self.conv_body = nn.Sequential(*conv_body)
+
+ self.final_conv = ConvLayer(in_channels + 1, channels['4'], 3, bias=True, activate=True)
+ self.final_linear = nn.Sequential(
+ EqualLinear(
+ channels['4'] * 4 * 4, channels['4'], bias=True, bias_init_val=0, lr_mul=1, activation='fused_lrelu'),
+ EqualLinear(channels['4'], 1, bias=True, bias_init_val=0, lr_mul=1, activation=None),
+ )
+ self.stddev_group = stddev_group
+ self.stddev_feat = 1
+
+ def forward(self, x):
+ out = self.conv_body(x)
+
+ b, c, h, w = out.shape
+ # concatenate a group stddev statistics to out
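+ # (minibatch stddev trick: the averaged per-group feature stddev is appended
+ # as one extra channel, helping the discriminator penalise low diversity)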
+ group = min(b, self.stddev_group) # Minibatch must be divisible by (or smaller than) group_size
+ stddev = out.view(group, -1, self.stddev_feat, c // self.stddev_feat, h, w)
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
+ stddev = stddev.repeat(group, 1, h, w)
+ out = torch.cat([out, stddev], 1)
+
+ out = self.final_conv(out)
+ out = out.view(b, -1)
+ out = self.final_linear(out)
+
+ return out
diff --git a/StableSR/basicsr/archs/stylegan2_bilinear_arch.py b/StableSR/basicsr/archs/stylegan2_bilinear_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..2395170411f9d11f2798ac03cf6ec6eb32fe5e43
--- /dev/null
+++ b/StableSR/basicsr/archs/stylegan2_bilinear_arch.py
@@ -0,0 +1,614 @@
+import math
+import random
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+class NormStyleCode(nn.Module):
+
+ def forward(self, x):
+ """Normalize the style codes.
+
+ Args:
+ x (Tensor): Style codes with shape (b, c).
+
+ Returns:
+ Tensor: Normalized tensor.
+ """
+ return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
+
+
+class EqualLinear(nn.Module):
+ """Equalized Linear as StyleGAN2.
+
+ Args:
+ in_channels (int): Size of each sample.
+ out_channels (int): Size of each output sample.
+ bias (bool): If set to ``False``, the layer will not learn an additive
+ bias. Default: ``True``.
+ bias_init_val (float): Bias initialized value. Default: 0.
+ lr_mul (float): Learning rate multiplier. Default: 1.
+ activation (None | str): The activation after ``linear`` operation.
+ Supported: 'fused_lrelu', None. Default: None.
+ """
+
+ def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None):
+ super(EqualLinear, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.lr_mul = lr_mul
+ self.activation = activation
+ if self.activation not in ['fused_lrelu', None]:
+ raise ValueError(f'Wrong activation value in EqualLinear: {activation}. '
+ "Supported ones are: ['fused_lrelu', None].")
+ self.scale = (1 / math.sqrt(in_channels)) * lr_mul
+
+ self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
+ else:
+ self.register_parameter('bias', None)
+
+ def forward(self, x):
+ if self.bias is None:
+ bias = None
+ else:
+ bias = self.bias * self.lr_mul
+ if self.activation == 'fused_lrelu':
+ out = F.linear(x, self.weight * self.scale)
+ out = fused_leaky_relu(out, bias)
+ else:
+ out = F.linear(x, self.weight * self.scale, bias=bias)
+ return out
+
+ def __repr__(self):
+ return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
+ f'out_channels={self.out_channels}, bias={self.bias is not None})')
+
+
+class ModulatedConv2d(nn.Module):
+ """Modulated Conv2d used in StyleGAN2.
+
+ There is no bias in ModulatedConv2d.
+
+ Args:
+ in_channels (int): Channel number of the input.
+ out_channels (int): Channel number of the output.
+ kernel_size (int): Size of the convolving kernel.
+ num_style_feat (int): Channel number of style features.
+ demodulate (bool): Whether to demodulate in the conv layer.
+ Default: True.
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
+ Default: None.
+ eps (float): A value added to the denominator for numerical stability.
+ Default: 1e-8.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ num_style_feat,
+ demodulate=True,
+ sample_mode=None,
+ eps=1e-8,
+ interpolation_mode='bilinear'):
+ super(ModulatedConv2d, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.demodulate = demodulate
+ self.sample_mode = sample_mode
+ self.eps = eps
+ self.interpolation_mode = interpolation_mode
+ if self.interpolation_mode == 'nearest':
+ self.align_corners = None
+ else:
+ self.align_corners = False
+
+ self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
+ # modulation inside each modulated conv
+ self.modulation = EqualLinear(
+ num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)
+
+ self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
+ self.padding = kernel_size // 2
+
+ def forward(self, x, style):
+ """Forward function.
+
+ Args:
+ x (Tensor): Tensor with shape (b, c, h, w).
+ style (Tensor): Tensor with shape (b, num_style_feat).
+
+ Returns:
+ Tensor: Modulated tensor after convolution.
+ """
+ b, c, h, w = x.shape # c = c_in
+ # weight modulation
+ style = self.modulation(style).view(b, 1, c, 1, 1)
+ # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
+ weight = self.scale * self.weight * style # (b, c_out, c_in, k, k)
+
+ if self.demodulate:
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
+ weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
+
+ weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
+
+ if self.sample_mode == 'upsample':
+ x = F.interpolate(x, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
+ elif self.sample_mode == 'downsample':
+ x = F.interpolate(x, scale_factor=0.5, mode=self.interpolation_mode, align_corners=self.align_corners)
+
+ b, c, h, w = x.shape
+ x = x.view(1, b * c, h, w)
+ # weight: (b*c_out, c_in, k, k), groups=b
+ out = F.conv2d(x, weight, padding=self.padding, groups=b)
+ out = out.view(b, self.out_channels, *out.shape[2:4])
+
+ return out
+
+ def __repr__(self):
+ return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
+ f'out_channels={self.out_channels}, '
+ f'kernel_size={self.kernel_size}, '
+ f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
+
+
+class StyleConv(nn.Module):
+ """Style conv.
+
+ Args:
+ in_channels (int): Channel number of the input.
+ out_channels (int): Channel number of the output.
+ kernel_size (int): Size of the convolving kernel.
+ num_style_feat (int): Channel number of style features.
+ demodulate (bool): Whether demodulate in the conv layer. Default: True.
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
+ Default: None.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ num_style_feat,
+ demodulate=True,
+ sample_mode=None,
+ interpolation_mode='bilinear'):
+ super(StyleConv, self).__init__()
+ self.modulated_conv = ModulatedConv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ num_style_feat,
+ demodulate=demodulate,
+ sample_mode=sample_mode,
+ interpolation_mode=interpolation_mode)
+ self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
+ self.activate = FusedLeakyReLU(out_channels)
+
+ def forward(self, x, style, noise=None):
+ # modulate
+ out = self.modulated_conv(x, style)
+ # noise injection
+ if noise is None:
+ b, _, h, w = out.shape
+ noise = out.new_empty(b, 1, h, w).normal_()
+ out = out + self.weight * noise
+ # activation (with bias)
+ out = self.activate(out)
+ return out
+
+
+class ToRGB(nn.Module):
+ """To RGB from features.
+
+ Args:
+ in_channels (int): Channel number of input.
+ num_style_feat (int): Channel number of style features.
+ upsample (bool): Whether to upsample. Default: True.
+ """
+
+ def __init__(self, in_channels, num_style_feat, upsample=True, interpolation_mode='bilinear'):
+ super(ToRGB, self).__init__()
+ self.upsample = upsample
+ self.interpolation_mode = interpolation_mode
+ if self.interpolation_mode == 'nearest':
+ self.align_corners = None
+ else:
+ self.align_corners = False
+ self.modulated_conv = ModulatedConv2d(
+ in_channels,
+ 3,
+ kernel_size=1,
+ num_style_feat=num_style_feat,
+ demodulate=False,
+ sample_mode=None,
+ interpolation_mode=interpolation_mode)
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
+
+ def forward(self, x, style, skip=None):
+ """Forward function.
+
+ Args:
+ x (Tensor): Feature tensor with shape (b, c, h, w).
+ style (Tensor): Tensor with shape (b, num_style_feat).
+ skip (Tensor): Base/skip tensor. Default: None.
+
+ Returns:
+ Tensor: RGB images.
+ """
+ out = self.modulated_conv(x, style)
+ out = out + self.bias
+ if skip is not None:
+ if self.upsample:
+ skip = F.interpolate(
+ skip, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
+ out = out + skip
+ return out
+
+
+class ConstantInput(nn.Module):
+ """Constant input.
+
+ Args:
+ num_channel (int): Channel number of constant input.
+ size (int): Spatial size of constant input.
+ """
+
+ def __init__(self, num_channel, size):
+ super(ConstantInput, self).__init__()
+ self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
+
+ def forward(self, batch):
+ out = self.weight.repeat(batch, 1, 1, 1)
+ return out
+
+
+@ARCH_REGISTRY.register(suffix='basicsr')
+class StyleGAN2GeneratorBilinear(nn.Module):
+ """StyleGAN2 Generator.
+
+ Args:
+ out_size (int): The spatial size of outputs.
+ num_style_feat (int): Channel number of style features. Default: 512.
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
+ channel_multiplier (int): Channel multiplier for large networks of
+ StyleGAN2. Default: 2.
+ lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
+ narrow (float): Narrow ratio for channels. Default: 1.0.
+ """
+
+ def __init__(self,
+ out_size,
+ num_style_feat=512,
+ num_mlp=8,
+ channel_multiplier=2,
+ lr_mlp=0.01,
+ narrow=1,
+ interpolation_mode='bilinear'):
+ super(StyleGAN2GeneratorBilinear, self).__init__()
+ # Style MLP layers
+ self.num_style_feat = num_style_feat
+ style_mlp_layers = [NormStyleCode()]
+ for i in range(num_mlp):
+ style_mlp_layers.append(
+ EqualLinear(
+ num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp,
+ activation='fused_lrelu'))
+ self.style_mlp = nn.Sequential(*style_mlp_layers)
+
+ channels = {
+ '4': int(512 * narrow),
+ '8': int(512 * narrow),
+ '16': int(512 * narrow),
+ '32': int(512 * narrow),
+ '64': int(256 * channel_multiplier * narrow),
+ '128': int(128 * channel_multiplier * narrow),
+ '256': int(64 * channel_multiplier * narrow),
+ '512': int(32 * channel_multiplier * narrow),
+ '1024': int(16 * channel_multiplier * narrow)
+ }
+ self.channels = channels
+
+ self.constant_input = ConstantInput(channels['4'], size=4)
+ self.style_conv1 = StyleConv(
+ channels['4'],
+ channels['4'],
+ kernel_size=3,
+ num_style_feat=num_style_feat,
+ demodulate=True,
+ sample_mode=None,
+ interpolation_mode=interpolation_mode)
+ self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, interpolation_mode=interpolation_mode)
+
+ self.log_size = int(math.log(out_size, 2))
+ self.num_layers = (self.log_size - 2) * 2 + 1
+ self.num_latent = self.log_size * 2 - 2
+
+ self.style_convs = nn.ModuleList()
+ self.to_rgbs = nn.ModuleList()
+ self.noises = nn.Module()
+
+ in_channels = channels['4']
+ # noise
+ for layer_idx in range(self.num_layers):
+ resolution = 2**((layer_idx + 5) // 2)
+ shape = [1, 1, resolution, resolution]
+ self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
+ # style convs and to_rgbs
+ for i in range(3, self.log_size + 1):
+ out_channels = channels[f'{2**i}']
+ self.style_convs.append(
+ StyleConv(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ num_style_feat=num_style_feat,
+ demodulate=True,
+ sample_mode='upsample',
+ interpolation_mode=interpolation_mode))
+ self.style_convs.append(
+ StyleConv(
+ out_channels,
+ out_channels,
+ kernel_size=3,
+ num_style_feat=num_style_feat,
+ demodulate=True,
+ sample_mode=None,
+ interpolation_mode=interpolation_mode))
+ self.to_rgbs.append(
+ ToRGB(out_channels, num_style_feat, upsample=True, interpolation_mode=interpolation_mode))
+ in_channels = out_channels
+
+ def make_noise(self):
+ """Make noise for noise injection."""
+ device = self.constant_input.weight.device
+ noises = [torch.randn(1, 1, 4, 4, device=device)]
+
+ for i in range(3, self.log_size + 1):
+ for _ in range(2):
+ noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
+
+ return noises
+
+ def get_latent(self, x):
+ return self.style_mlp(x)
+
+ def mean_latent(self, num_latent):
+ latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
+ latent = self.style_mlp(latent_in).mean(0, keepdim=True)
+ return latent
+
+ def forward(self,
+ styles,
+ input_is_latent=False,
+ noise=None,
+ randomize_noise=True,
+ truncation=1,
+ truncation_latent=None,
+ inject_index=None,
+ return_latents=False):
+ """Forward function for StyleGAN2Generator.
+
+ Args:
+ styles (list[Tensor]): Sample codes of styles.
+ input_is_latent (bool): Whether input is latent style.
+ Default: False.
+ noise (Tensor | None): Input noise or None. Default: None.
+ randomize_noise (bool): Randomize noise, used when 'noise' is
+ None. Default: True.
+ truncation (float): The truncation ratio for the truncation trick.
+ 1 means no truncation. Default: 1.
+ truncation_latent (Tensor | None): The mean latent used as the
+ truncation center. Required when truncation < 1. Default: None.
+ inject_index (int | None): The injection index for mixing noise.
+ Default: None.
+ return_latents (bool): Whether to return style latents.
+ Default: False.
+ """
+ # style codes -> latents with Style MLP layer
+ if not input_is_latent:
+ styles = [self.style_mlp(s) for s in styles]
+ # noises
+ if noise is None:
+ if randomize_noise:
+ noise = [None] * self.num_layers # for each style conv layer
+ else: # use the stored noise
+ noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
+ # style truncation
+ if truncation < 1:
+ style_truncation = []
+ for style in styles:
+ style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
+ styles = style_truncation
+ # get style latent with injection
+ if len(styles) == 1:
+ inject_index = self.num_latent
+
+ if styles[0].ndim < 3:
+ # repeat latent code for all the layers
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+ else: # used for encoder with different latent code for each layer
+ latent = styles[0]
+ elif len(styles) == 2: # mixing noises
+ if inject_index is None:
+ inject_index = random.randint(1, self.num_latent - 1)
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
+ latent = torch.cat([latent1, latent2], 1)
+
+ # main generation
+ out = self.constant_input(latent.shape[0])
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
+ skip = self.to_rgb1(out, latent[:, 1])
+
+ i = 1
+ for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
+ noise[2::2], self.to_rgbs):
+ out = conv1(out, latent[:, i], noise=noise1)
+ out = conv2(out, latent[:, i + 1], noise=noise2)
+ skip = to_rgb(out, latent[:, i + 2], skip)
+ i += 2
+
+ image = skip
+
+ if return_latents:
+ return image, latent
+ else:
+ return image, None
+
+
+class ScaledLeakyReLU(nn.Module):
+ """Scaled LeakyReLU.
+
+ Args:
+ negative_slope (float): Negative slope. Default: 0.2.
+ """
+
+ def __init__(self, negative_slope=0.2):
+ super(ScaledLeakyReLU, self).__init__()
+ self.negative_slope = negative_slope
+
+ def forward(self, x):
+ out = F.leaky_relu(x, negative_slope=self.negative_slope)
+ return out * math.sqrt(2)
+
+
+class EqualConv2d(nn.Module):
+ """Equalized Linear as StyleGAN2.
+
+ Args:
+ in_channels (int): Channel number of the input.
+ out_channels (int): Channel number of the output.
+ kernel_size (int): Size of the convolving kernel.
+ stride (int): Stride of the convolution. Default: 1
+ padding (int): Zero-padding added to both sides of the input.
+ Default: 0.
+ bias (bool): If ``True``, adds a learnable bias to the output.
+ Default: ``True``.
+ bias_init_val (float): Bias initialized value. Default: 0.
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0):
+ super(EqualConv2d, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
+
+ self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
+ else:
+ self.register_parameter('bias', None)
+
+ def forward(self, x):
+ out = F.conv2d(
+ x,
+ self.weight * self.scale,
+ bias=self.bias,
+ stride=self.stride,
+ padding=self.padding,
+ )
+
+ return out
+
+ def __repr__(self):
+ return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
+ f'out_channels={self.out_channels}, '
+ f'kernel_size={self.kernel_size},'
+ f' stride={self.stride}, padding={self.padding}, '
+ f'bias={self.bias is not None})')
+
+
+class ConvLayer(nn.Sequential):
+ """Conv Layer used in StyleGAN2 Discriminator.
+
+ Args:
+ in_channels (int): Channel number of the input.
+ out_channels (int): Channel number of the output.
+ kernel_size (int): Kernel size.
+ downsample (bool): Whether to downsample by a factor of 2.
+ Default: False.
+ bias (bool): Whether to use bias. Default: True.
+ activate (bool): Whether to use activation. Default: True.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ downsample=False,
+ bias=True,
+ activate=True,
+ interpolation_mode='bilinear'):
+ layers = []
+ self.interpolation_mode = interpolation_mode
+ # downsample
+ if downsample:
+ if self.interpolation_mode == 'nearest':
+ self.align_corners = None
+ else:
+ self.align_corners = False
+
+ layers.append(
+ torch.nn.Upsample(scale_factor=0.5, mode=interpolation_mode, align_corners=self.align_corners))
+ stride = 1
+ self.padding = kernel_size // 2
+ # conv
+ layers.append(
+ EqualConv2d(
+ in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias
+ and not activate))
+ # activation
+ if activate:
+ if bias:
+ layers.append(FusedLeakyReLU(out_channels))
+ else:
+ layers.append(ScaledLeakyReLU(0.2))
+
+ super(ConvLayer, self).__init__(*layers)
+
+
+class ResBlock(nn.Module):
+ """Residual block used in StyleGAN2 Discriminator.
+
+ Args:
+ in_channels (int): Channel number of the input.
+ out_channels (int): Channel number of the output.
+ """
+
+ def __init__(self, in_channels, out_channels, interpolation_mode='bilinear'):
+ super(ResBlock, self).__init__()
+
+ self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
+ self.conv2 = ConvLayer(
+ in_channels,
+ out_channels,
+ 3,
+ downsample=True,
+ interpolation_mode=interpolation_mode,
+ bias=True,
+ activate=True)
+ self.skip = ConvLayer(
+ in_channels,
+ out_channels,
+ 1,
+ downsample=True,
+ interpolation_mode=interpolation_mode,
+ bias=False,
+ activate=False)
+
+ def forward(self, x):
+ out = self.conv1(x)
+ out = self.conv2(out)
+ skip = self.skip(x)
+ out = (out + skip) / math.sqrt(2)
+ return out
diff --git a/StableSR/basicsr/archs/swinir_arch.py b/StableSR/basicsr/archs/swinir_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..3917fa2c7408e1f5b55b9930c643a9af920a4d81
--- /dev/null
+++ b/StableSR/basicsr/archs/swinir_arch.py
@@ -0,0 +1,956 @@
+# Modified from https://github.com/JingyunLiang/SwinIR
+# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
+# Originally Written by Ze Liu, Modified by Jingyun Liang.
+
+import math
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import to_2tuple, trunc_normal_
+
+
+def drop_path(x, drop_prob: float = 0., training: bool = False):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
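+ # each sample survives with probability keep_prob; survivors are scaled by
+ # 1/keep_prob so the expected output equals the input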
+ shape = (x.shape[0], ) + (1, ) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_() # binarize
+ output = x.div(keep_prob) * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
+ """
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+
+class Mlp(nn.Module):
+
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (b, h, w, c)
+ window_size (int): window size
+
+ Returns:
+ windows: (num_windows*b, window_size, window_size, c)
+ """
+ b, h, w, c = x.shape
+ x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
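+ # e.g. (2, 64, 64, 96) with window_size=8 -> (2*8*8, 8, 8, 96)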
+ return windows
+
+
+def window_reverse(windows, window_size, h, w):
+ """
+ Args:
+ windows: (num_windows*b, window_size, window_size, c)
+ window_size (int): Window size
+ h (int): Height of image
+ w (int): Width of image
+
+ Returns:
+ x: (b, h, w, c)
+ """
+ b = int(windows.shape[0] / (h * w / window_size / window_size))
+ x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
+ return x
+
+
+class WindowAttention(nn.Module):
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both shifted and non-shifted windows.
+
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ """
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer('relative_position_index', relative_position_index)
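+ # every (query, key) pair inside a window is thus mapped to one of
+ # (2*Wh-1)*(2*Ww-1) relative offsets indexing the learned bias table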
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ trunc_normal_(self.relative_position_bias_table, std=.02)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+ x: input features with shape of (num_windows*b, n, c)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ b_, n, c = x.shape
+ qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nw = mask.shape[0]
+ attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, n, n)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
+
+ def flops(self, n):
+ # calculate flops for 1 window with token length of n
+ flops = 0
+ # qkv = self.qkv(x)
+ flops += n * self.dim * 3 * self.dim
+ # attn = (q @ k.transpose(-2, -1))
+ flops += self.num_heads * n * (self.dim // self.num_heads) * n
+ # x = (attn @ v)
+ flops += self.num_heads * n * n * (self.dim // self.num_heads)
+ # x = self.proj(x)
+ flops += n * self.dim * self.dim
+ return flops
+
+
+class SwinTransformerBlock(nn.Module):
+ r""" Swin Transformer Block.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self,
+ dim,
+ input_resolution,
+ num_heads,
+ window_size=7,
+ shift_size=0,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ if min(self.input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = 0
+ self.window_size = min(self.input_resolution)
+ assert 0 <= self.shift_size < self.window_size, 'shift_size must be in the range [0, window_size)'
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim,
+ window_size=to_2tuple(self.window_size),
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if self.shift_size > 0:
+ attn_mask = self.calculate_mask(self.input_resolution)
+ else:
+ attn_mask = None
+
+ self.register_buffer('attn_mask', attn_mask)
+
+ def calculate_mask(self, x_size):
+ # calculate attention mask for SW-MSA
+ h, w = x_size
+ img_mask = torch.zeros((1, h, w, 1)) # 1 h w 1
+ h_slices = (slice(0, -self.window_size), slice(-self.window_size,
+ -self.shift_size), slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size), slice(-self.window_size,
+ -self.shift_size), slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nw, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
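+ # token pairs coming from different regions (after the cyclic shift) receive
+ # -100, which effectively removes their attention weight after the softmax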
+
+ return attn_mask
+
+ def forward(self, x, x_size):
+ h, w = x_size
+ b, _, c = x.shape
+ # assert seq_len == h * w, "input feature has wrong size"
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(b, h, w, c)
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
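+ # the cyclic shift lets the following window partition mix tokens across the
+ # boundaries of the previous (non-shifted) windows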
+ else:
+ shifted_x = x
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nw*b, window_size, window_size, c
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c
+
+ # W-MSA/SW-MSA (recompute the mask when the input size differs from the training
+ # resolution, so testing works on images whose sizes are multiples of the window size)
+ if self.input_resolution == x_size:
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nw*b, window_size*window_size, c
+ else:
+ attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
+ shifted_x = window_reverse(attn_windows, self.window_size, h, w) # b h' w' c
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+ x = x.view(b, h * w, c)
+
+ # FFN
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ return x
+
+ def extra_repr(self) -> str:
+ return (f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, '
+ f'window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}')
+
+ def flops(self):
+ flops = 0
+ h, w = self.input_resolution
+ # norm1
+ flops += self.dim * h * w
+ # W-MSA/SW-MSA
+ nw = h * w / self.window_size / self.window_size
+ flops += nw * self.attn.flops(self.window_size * self.window_size)
+ # mlp
+ flops += 2 * h * w * self.dim * self.dim * self.mlp_ratio
+ # norm2
+ flops += self.dim * h * w
+ return flops
+
+
+class PatchMerging(nn.Module):
+ r""" Patch Merging Layer.
+
+ Args:
+ input_resolution (tuple[int]): Resolution of input feature.
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def forward(self, x):
+ """
+ x: b, h*w, c
+ """
+ h, w = self.input_resolution
+ b, seq_len, c = x.shape
+ assert seq_len == h * w, 'input feature has wrong size'
+ assert h % 2 == 0 and w % 2 == 0, f'x size ({h}*{w}) are not even.'
+
+ x = x.view(b, h, w, c)
+
+ x0 = x[:, 0::2, 0::2, :] # b h/2 w/2 c
+ x1 = x[:, 1::2, 0::2, :] # b h/2 w/2 c
+ x2 = x[:, 0::2, 1::2, :] # b h/2 w/2 c
+ x3 = x[:, 1::2, 1::2, :] # b h/2 w/2 c
+ x = torch.cat([x0, x1, x2, x3], -1) # b h/2 w/2 4*c
+ x = x.view(b, -1, 4 * c) # b h/2*w/2 4*c
+
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f'input_resolution={self.input_resolution}, dim={self.dim}'
+
+ def flops(self):
+ h, w = self.input_resolution
+ flops = h * w * self.dim
+ flops += (h // 2) * (w // 2) * 4 * self.dim * 2 * self.dim
+ return flops
+
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(self,
+ dim,
+ input_resolution,
+ depth,
+ num_heads,
+ window_size,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ use_checkpoint=False):
+
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlock(
+ dim=dim,
+ input_resolution=input_resolution,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer) for i in range(depth)
+ ])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x, x_size):
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x, x_size)  # blk.forward requires x_size
+ else:
+ x = blk(x, x_size)
+ if self.downsample is not None:
+ x = self.downsample(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
+
+ def flops(self):
+ flops = 0
+ for blk in self.blocks:
+ flops += blk.flops()
+ if self.downsample is not None:
+ flops += self.downsample.flops()
+ return flops
+
+
+class RSTB(nn.Module):
+ """Residual Swin Transformer Block (RSTB).
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ img_size: Input image size.
+ patch_size: Patch size.
+ resi_connection: The convolutional block before residual connection.
+ """
+
+ def __init__(self,
+ dim,
+ input_resolution,
+ depth,
+ num_heads,
+ window_size,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ use_checkpoint=False,
+ img_size=224,
+ patch_size=4,
+ resi_connection='1conv'):
+ super(RSTB, self).__init__()
+
+ self.dim = dim
+ self.input_resolution = input_resolution
+
+ self.residual_group = BasicLayer(
+ dim=dim,
+ input_resolution=input_resolution,
+ depth=depth,
+ num_heads=num_heads,
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path,
+ norm_layer=norm_layer,
+ downsample=downsample,
+ use_checkpoint=use_checkpoint)
+
+ if resi_connection == '1conv':
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
+ elif resi_connection == '3conv':
+ # to save parameters and memory
+ self.conv = nn.Sequential(
+ nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
+
+ self.patch_unembed = PatchUnEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
+
+ def forward(self, x, x_size):
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
+
+ def flops(self):
+ flops = 0
+ flops += self.residual_group.flops()
+ h, w = self.input_resolution
+ flops += h * w * self.dim * self.dim * 9
+ flops += self.patch_embed.flops()
+ flops += self.patch_unembed.flops()
+
+ return flops
+
+
+class PatchEmbed(nn.Module):
+ r""" Image to Patch Embedding
+
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ x = x.flatten(2).transpose(1, 2) # b Ph*Pw c
+ if self.norm is not None:
+ x = self.norm(x)
+ return x
+
+ def flops(self):
+ flops = 0
+ h, w = self.img_size
+ if self.norm is not None:
+ flops += h * w * self.embed_dim
+ return flops
+
+
+class PatchUnEmbed(nn.Module):
+ r""" Image to Patch Unembedding
+
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ def forward(self, x, x_size):
+ x = x.transpose(1, 2).view(x.shape[0], self.embed_dim, x_size[0], x_size[1]) # b c Ph Pw
+ return x
+
+ def flops(self):
+ flops = 0
+ return flops
+
+
+class Upsample(nn.Sequential):
+ """Upsample module.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+ """
+
+ def __init__(self, scale, num_feat):
+ m = []
+ if (scale & (scale - 1)) == 0: # scale = 2^n
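+ # each 2x stage expands channels 4x with a conv, then PixelShuffle rearranges
+ # them into 2x spatial resolution; repeated log2(scale) times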
+ for _ in range(int(math.log(scale, 2))):
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(2))
+ elif scale == 3:
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(3))
+ else:
+ raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
+ super(Upsample, self).__init__(*m)
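+
+# A minimal usage sketch (not part of the original file): for 2^n scales the module
+# stacks conv + PixelShuffle pairs, each doubling the spatial size while keeping
+# `num_feat` channels, e.g.:
+#     up = Upsample(scale=4, num_feat=64)
+#     y = up(torch.randn(1, 64, 8, 8))   # expected shape: (1, 64, 32, 32)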
+
+
+class UpsampleOneStep(nn.Sequential):
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
+ Used in lightweight SR to save parameters.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+
+ """
+
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
+ self.num_feat = num_feat
+ self.input_resolution = input_resolution
+ m = []
+ m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
+ m.append(nn.PixelShuffle(scale))
+ super(UpsampleOneStep, self).__init__(*m)
+
+ def flops(self):
+ h, w = self.input_resolution
+ flops = h * w * self.num_feat * 3 * 9
+ return flops
+
+
+@ARCH_REGISTRY.register()
+class SwinIR(nn.Module):
+ r""" SwinIR
+    A PyTorch implementation of `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
+
+ Args:
+ img_size (int | tuple(int)): Input image size. Default 64
+ patch_size (int | tuple(int)): Patch size. Default: 1
+ in_chans (int): Number of input image channels. Default: 3
+ embed_dim (int): Patch embedding dimension. Default: 96
+ depths (tuple(int)): Depth of each Swin Transformer layer.
+ num_heads (tuple(int)): Number of attention heads in different layers.
+ window_size (int): Window size. Default: 7
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
+ drop_rate (float): Dropout rate. Default: 0
+ attn_drop_rate (float): Attention dropout rate. Default: 0
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+        upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
+ img_range: Image range. 1. or 255.
+        upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
+ """
+
+ def __init__(self,
+ img_size=64,
+ patch_size=1,
+ in_chans=3,
+ embed_dim=96,
+ depths=(6, 6, 6, 6),
+ num_heads=(6, 6, 6, 6),
+ window_size=7,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm,
+ ape=False,
+ patch_norm=True,
+ use_checkpoint=False,
+ upscale=2,
+ img_range=1.,
+ upsampler='',
+ resi_connection='1conv',
+ **kwargs):
+ super(SwinIR, self).__init__()
+ num_in_ch = in_chans
+ num_out_ch = in_chans
+ num_feat = 64
+ self.img_range = img_range
+ if in_chans == 3:
+ rgb_mean = (0.4488, 0.4371, 0.4040)
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
+ else:
+ self.mean = torch.zeros(1, 1, 1, 1)
+ self.upscale = upscale
+ self.upsampler = upsampler
+
+ # ------------------------- 1, shallow feature extraction ------------------------- #
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
+
+ # ------------------------- 2, deep feature extraction ------------------------- #
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.num_features = embed_dim
+ self.mlp_ratio = mlp_ratio
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=embed_dim,
+ embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+ num_patches = self.patch_embed.num_patches
+ patches_resolution = self.patch_embed.patches_resolution
+ self.patches_resolution = patches_resolution
+
+ # merge non-overlapping patches into image
+ self.patch_unembed = PatchUnEmbed(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=embed_dim,
+ embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+
+ # absolute position embedding
+ if self.ape:
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+ trunc_normal_(self.absolute_pos_embed, std=.02)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build Residual Swin Transformer blocks (RSTB)
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = RSTB(
+ dim=embed_dim,
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
+ norm_layer=norm_layer,
+ downsample=None,
+ use_checkpoint=use_checkpoint,
+ img_size=img_size,
+ patch_size=patch_size,
+ resi_connection=resi_connection)
+ self.layers.append(layer)
+ self.norm = norm_layer(self.num_features)
+
+ # build the last conv layer in deep feature extraction
+ if resi_connection == '1conv':
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
+ elif resi_connection == '3conv':
+ # to save parameters and memory
+ self.conv_after_body = nn.Sequential(
+ nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
+
+ # ------------------------- 3, high quality image reconstruction ------------------------- #
+ if self.upsampler == 'pixelshuffle':
+ # for classical SR
+ self.conv_before_upsample = nn.Sequential(
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
+ self.upsample = Upsample(upscale, num_feat)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ elif self.upsampler == 'pixelshuffledirect':
+ # for lightweight SR (to save parameters)
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
+ (patches_resolution[0], patches_resolution[1]))
+ elif self.upsampler == 'nearest+conv':
+ # for real-world SR (less artifacts)
+ assert self.upscale == 4, 'only support x4 now.'
+ self.conv_before_upsample = nn.Sequential(
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+ else:
+ # for image denoising and JPEG compression artifact reduction
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'absolute_pos_embed'}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {'relative_position_bias_table'}
+
+ def forward_features(self, x):
+ x_size = (x.shape[2], x.shape[3])
+ x = self.patch_embed(x)
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+
+ for layer in self.layers:
+ x = layer(x, x_size)
+
+ x = self.norm(x) # b seq_len c
+ x = self.patch_unembed(x, x_size)
+
+ return x
+
+ def forward(self, x):
+ self.mean = self.mean.type_as(x)
+ x = (x - self.mean) * self.img_range
+
+ if self.upsampler == 'pixelshuffle':
+ # for classical SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.conv_before_upsample(x)
+ x = self.conv_last(self.upsample(x))
+ elif self.upsampler == 'pixelshuffledirect':
+ # for lightweight SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.upsample(x)
+ elif self.upsampler == 'nearest+conv':
+ # for real-world SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.conv_before_upsample(x)
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
+ else:
+ # for image denoising and JPEG compression artifact reduction
+ x_first = self.conv_first(x)
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
+ x = x + self.conv_last(res)
+
+ x = x / self.img_range + self.mean
+
+ return x
+
+ def flops(self):
+ flops = 0
+ h, w = self.patches_resolution
+ flops += h * w * 3 * self.embed_dim * 9
+ flops += self.patch_embed.flops()
+ for layer in self.layers:
+ flops += layer.flops()
+ flops += h * w * 3 * self.embed_dim * self.embed_dim
+ flops += self.upsample.flops()
+ return flops
+
+
+if __name__ == '__main__':
+ upscale = 4
+ window_size = 8
+ height = (1024 // upscale // window_size + 1) * window_size
+ width = (720 // upscale // window_size + 1) * window_size
+ model = SwinIR(
+ upscale=2,
+ img_size=(height, width),
+ window_size=window_size,
+ img_range=1.,
+ depths=[6, 6, 6, 6],
+ embed_dim=60,
+ num_heads=[6, 6, 6, 6],
+ mlp_ratio=2,
+ upsampler='pixelshuffledirect')
+ print(model)
+ print(height, width, model.flops() / 1e9)
+
+ x = torch.randn((1, 3, height, width))
+ x = model(x)
+ print(x.shape)
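+
+    # A hedged sketch (not in the original script): this implementation expects h and w
+    # to be multiples of window_size, as computed above. For arbitrary inputs, one common
+    # approach is to reflection-pad before the forward pass and crop the output afterwards:
+    #     _, _, h, w = x.shape
+    #     pad_h = (window_size - h % window_size) % window_size
+    #     pad_w = (window_size - w % window_size) % window_size
+    #     x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect')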
diff --git a/StableSR/basicsr/archs/tof_arch.py b/StableSR/basicsr/archs/tof_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..a90a64d89386e19f92c987bbe2133472991d764a
--- /dev/null
+++ b/StableSR/basicsr/archs/tof_arch.py
@@ -0,0 +1,172 @@
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import flow_warp
+
+
+class BasicModule(nn.Module):
+ """Basic module of SPyNet.
+
+ Note that unlike the architecture in spynet_arch.py, the basic module
+ here contains batch normalization.
+ """
+
+ def __init__(self):
+ super(BasicModule, self).__init__()
+ self.basic_module = nn.Sequential(
+ nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False),
+ nn.BatchNorm2d(32), nn.ReLU(inplace=True),
+ nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3, bias=False),
+ nn.BatchNorm2d(64), nn.ReLU(inplace=True),
+ nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False),
+ nn.BatchNorm2d(32), nn.ReLU(inplace=True),
+ nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3, bias=False),
+ nn.BatchNorm2d(16), nn.ReLU(inplace=True),
+ nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))
+
+ def forward(self, tensor_input):
+ """
+ Args:
+ tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
+ 8 channels contain:
+ [reference image (3), neighbor image (3), initial flow (2)].
+
+ Returns:
+ Tensor: Estimated flow with shape (b, 2, h, w)
+ """
+ return self.basic_module(tensor_input)
+
+
+class SPyNetTOF(nn.Module):
+ """SPyNet architecture for TOF.
+
+ Note that this implementation is specifically for TOFlow. Please use :file:`spynet_arch.py` for general use.
+ They differ in the following aspects:
+
+ 1. The basic modules here contain BatchNorm.
+ 2. Normalization and denormalization are not done here, as they are done in TOFlow.
+
+ ``Paper: Optical Flow Estimation using a Spatial Pyramid Network``
+
+ Reference: https://github.com/Coldog2333/pytoflow
+
+ Args:
+ load_path (str): Path for pretrained SPyNet. Default: None.
+ """
+
+ def __init__(self, load_path=None):
+ super(SPyNetTOF, self).__init__()
+
+ self.basic_module = nn.ModuleList([BasicModule() for _ in range(4)])
+ if load_path:
+ self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
+
+ def forward(self, ref, supp):
+ """
+ Args:
+ ref (Tensor): Reference image with shape of (b, 3, h, w).
+ supp: The supporting image to be warped: (b, 3, h, w).
+
+ Returns:
+ Tensor: Estimated optical flow: (b, 2, h, w).
+ """
+ num_batches, _, h, w = ref.size()
+ ref = [ref]
+ supp = [supp]
+
+ # generate downsampled frames
+ for _ in range(3):
+ ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False))
+ supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False))
+
+ # flow computation
+ flow = ref[0].new_zeros(num_batches, 2, h // 16, w // 16)
+ for i in range(4):
+ flow_up = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
+ flow = flow_up + self.basic_module[i](
+ torch.cat([ref[i], flow_warp(supp[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1))
+ return flow
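+
+# A minimal usage sketch (assumption, not part of the original file): flow is estimated
+# coarse-to-fine over a 4-level pyramid, so h and w should be divisible by 16:
+#     spynet = SPyNetTOF()
+#     flow = spynet(torch.randn(1, 3, 64, 64), torch.randn(1, 3, 64, 64))
+#     # expected shape: (1, 2, 64, 64)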
+
+
+@ARCH_REGISTRY.register()
+class TOFlow(nn.Module):
+ """PyTorch implementation of TOFlow.
+
+    In TOFlow, the LR frames are pre-upsampled and have the same size as the GT frames.
+
+ ``Paper: Video Enhancement with Task-Oriented Flow``
+
+ Reference: https://github.com/anchen1011/toflow
+
+ Reference: https://github.com/Coldog2333/pytoflow
+
+ Args:
+ adapt_official_weights (bool): Whether to adapt the weights translated
+ from the official implementation. Set to false if you want to
+ train from scratch. Default: False
+ """
+
+ def __init__(self, adapt_official_weights=False):
+ super(TOFlow, self).__init__()
+ self.adapt_official_weights = adapt_official_weights
+ self.ref_idx = 0 if adapt_official_weights else 3
+
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ # flow estimation module
+ self.spynet = SPyNetTOF()
+
+ # reconstruction module
+ self.conv_1 = nn.Conv2d(3 * 7, 64, 9, 1, 4)
+ self.conv_2 = nn.Conv2d(64, 64, 9, 1, 4)
+ self.conv_3 = nn.Conv2d(64, 64, 1)
+ self.conv_4 = nn.Conv2d(64, 3, 1)
+
+ # activation function
+ self.relu = nn.ReLU(inplace=True)
+
+ def normalize(self, img):
+ return (img - self.mean) / self.std
+
+ def denormalize(self, img):
+ return img * self.std + self.mean
+
+ def forward(self, lrs):
+ """
+ Args:
+ lrs: Input lr frames: (b, 7, 3, h, w).
+
+ Returns:
+ Tensor: SR frame: (b, 3, h, w).
+ """
+ # In the official implementation, the 0-th frame is the reference frame
+ if self.adapt_official_weights:
+ lrs = lrs[:, [3, 0, 1, 2, 4, 5, 6], :, :, :]
+
+ num_batches, num_lrs, _, h, w = lrs.size()
+
+ lrs = self.normalize(lrs.view(-1, 3, h, w))
+ lrs = lrs.view(num_batches, num_lrs, 3, h, w)
+
+ lr_ref = lrs[:, self.ref_idx, :, :, :]
+ lr_aligned = []
+ for i in range(7): # 7 frames
+ if i == self.ref_idx:
+ lr_aligned.append(lr_ref)
+ else:
+ lr_supp = lrs[:, i, :, :, :]
+ flow = self.spynet(lr_ref, lr_supp)
+ lr_aligned.append(flow_warp(lr_supp, flow.permute(0, 2, 3, 1)))
+
+ # reconstruction
+ hr = torch.stack(lr_aligned, dim=1)
+ hr = hr.view(num_batches, -1, h, w)
+ hr = self.relu(self.conv_1(hr))
+ hr = self.relu(self.conv_2(hr))
+ hr = self.relu(self.conv_3(hr))
+ hr = self.conv_4(hr) + lr_ref
+
+ return self.denormalize(hr)
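+
+# A minimal usage sketch (assumption, not part of the original file): TOFlow consumes
+# 7 pre-upsampled frames and returns the restored center frame; h and w should be
+# divisible by 16 so the SPyNet pyramid resolutions line up, e.g.:
+#     model = TOFlow()
+#     lrs = torch.randn(1, 7, 3, 64, 64)
+#     out = model(lrs)   # expected shape: (1, 3, 64, 64)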
diff --git a/StableSR/basicsr/archs/vgg_arch.py b/StableSR/basicsr/archs/vgg_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..05200334e477e59feefd1e4a0b5e94204e4eb2fa
--- /dev/null
+++ b/StableSR/basicsr/archs/vgg_arch.py
@@ -0,0 +1,161 @@
+import os
+import torch
+from collections import OrderedDict
+from torch import nn as nn
+from torchvision.models import vgg as vgg
+
+from basicsr.utils.registry import ARCH_REGISTRY
+
+VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
+NAMES = {
+ 'vgg11': [
+ 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
+ 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
+ 'pool5'
+ ],
+ 'vgg13': [
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
+ 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
+ ],
+ 'vgg16': [
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
+ 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
+ 'pool5'
+ ],
+ 'vgg19': [
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
+ 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
+ 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
+ ]
+}
+
+
+def insert_bn(names):
+ """Insert bn layer after each conv.
+
+ Args:
+ names (list): The list of layer names.
+
+ Returns:
+ list: The list of layer names with bn layers.
+ """
+ names_bn = []
+ for name in names:
+ names_bn.append(name)
+ if 'conv' in name:
+ position = name.replace('conv', '')
+ names_bn.append('bn' + position)
+ return names_bn
+
+
+@ARCH_REGISTRY.register()
+class VGGFeatureExtractor(nn.Module):
+ """VGG network for feature extraction.
+
+    In this implementation, we allow users to choose whether to normalize the
+    input features and which type of vgg network to use. Note that the
+    pretrained path must match the vgg type.
+
+ Args:
+ layer_name_list (list[str]): Forward function returns the corresponding
+ features according to the layer_name_list.
+ Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
+ vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
+        use_input_norm (bool): If True, normalize the input image. Importantly,
+            the input feature must be in the range [0, 1]. Default: True.
+        range_norm (bool): If True, normalize images from range [-1, 1] to [0, 1].
+            Default: False.
+ requires_grad (bool): If true, the parameters of VGG network will be
+ optimized. Default: False.
+ remove_pooling (bool): If true, the max pooling operations in VGG net
+ will be removed. Default: False.
+ pooling_stride (int): The stride of max pooling operation. Default: 2.
+ """
+
+ def __init__(self,
+ layer_name_list,
+ vgg_type='vgg19',
+ use_input_norm=True,
+ range_norm=False,
+ requires_grad=False,
+ remove_pooling=False,
+ pooling_stride=2):
+ super(VGGFeatureExtractor, self).__init__()
+
+ self.layer_name_list = layer_name_list
+ self.use_input_norm = use_input_norm
+ self.range_norm = range_norm
+
+ self.names = NAMES[vgg_type.replace('_bn', '')]
+ if 'bn' in vgg_type:
+ self.names = insert_bn(self.names)
+
+ # only borrow layers that will be used to avoid unused params
+ max_idx = 0
+ for v in layer_name_list:
+ idx = self.names.index(v)
+ if idx > max_idx:
+ max_idx = idx
+
+ if os.path.exists(VGG_PRETRAIN_PATH):
+ vgg_net = getattr(vgg, vgg_type)(pretrained=False)
+ state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
+ vgg_net.load_state_dict(state_dict)
+ else:
+ vgg_net = getattr(vgg, vgg_type)(pretrained=True)
+
+ features = vgg_net.features[:max_idx + 1]
+
+ modified_net = OrderedDict()
+ for k, v in zip(self.names, features):
+ if 'pool' in k:
+ # if remove_pooling is true, pooling operation will be removed
+ if remove_pooling:
+ continue
+ else:
+ # in some cases, we may want to change the default stride
+ modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
+ else:
+ modified_net[k] = v
+
+ self.vgg_net = nn.Sequential(modified_net)
+
+ if not requires_grad:
+ self.vgg_net.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+ else:
+ self.vgg_net.train()
+ for param in self.parameters():
+ param.requires_grad = True
+
+ if self.use_input_norm:
+ # the mean is for image with range [0, 1]
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ # the std is for image with range [0, 1]
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ def forward(self, x):
+ """Forward function.
+
+ Args:
+ x (Tensor): Input tensor with shape (n, c, h, w).
+
+ Returns:
+ Tensor: Forward results.
+ """
+ if self.range_norm:
+ x = (x + 1) / 2
+ if self.use_input_norm:
+ x = (x - self.mean) / self.std
+
+ output = {}
+ for key, layer in self.vgg_net._modules.items():
+ x = layer(x)
+ if key in self.layer_name_list:
+ output[key] = x.clone()
+
+ return output
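+
+# A minimal usage sketch (assumption, not part of the original file). Note that if
+# VGG_PRETRAIN_PATH does not exist, torchvision downloads the pretrained VGG weights
+# on first use:
+#     extractor = VGGFeatureExtractor(layer_name_list=['relu1_1', 'relu2_1'], vgg_type='vgg19')
+#     feats = extractor(torch.rand(1, 3, 224, 224))
+#     # feats is a dict: feats['relu1_1'] -> (1, 64, 224, 224), feats['relu2_1'] -> (1, 128, 112, 112)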
diff --git a/StableSR/basicsr/data/__init__.py b/StableSR/basicsr/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..510df16771d153f61fbf2126baac24f69d3de7e4
--- /dev/null
+++ b/StableSR/basicsr/data/__init__.py
@@ -0,0 +1,101 @@
+import importlib
+import numpy as np
+import random
+import torch
+import torch.utils.data
+from copy import deepcopy
+from functools import partial
+from os import path as osp
+
+from basicsr.data.prefetch_dataloader import PrefetchDataLoader
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.dist_util import get_dist_info
+from basicsr.utils.registry import DATASET_REGISTRY
+
+__all__ = ['build_dataset', 'build_dataloader']
+
+# automatically scan and import dataset modules for registry
+# scan all the files under the data folder with '_dataset' in file names
+data_folder = osp.dirname(osp.abspath(__file__))
+dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
+# import all the dataset modules
+_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
+
+
+def build_dataset(dataset_opt):
+ """Build dataset from options.
+
+ Args:
+ dataset_opt (dict): Configuration for dataset. It must contain:
+ name (str): Dataset name.
+ type (str): Dataset type.
+ """
+ dataset_opt = deepcopy(dataset_opt)
+ dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
+ logger = get_root_logger()
+ logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.')
+ return dataset
+
+
+def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
+ """Build dataloader.
+
+ Args:
+ dataset (torch.utils.data.Dataset): Dataset.
+ dataset_opt (dict): Dataset options. It contains the following keys:
+ phase (str): 'train' or 'val'.
+ num_worker_per_gpu (int): Number of workers for each GPU.
+ batch_size_per_gpu (int): Training batch size for each GPU.
+ num_gpu (int): Number of GPUs. Used only in the train phase.
+ Default: 1.
+ dist (bool): Whether in distributed training. Used only in the train
+ phase. Default: False.
+ sampler (torch.utils.data.sampler): Data sampler. Default: None.
+ seed (int | None): Seed. Default: None
+ """
+ phase = dataset_opt['phase']
+ rank, _ = get_dist_info()
+ if phase == 'train':
+ if dist: # distributed training
+ batch_size = dataset_opt['batch_size_per_gpu']
+ num_workers = dataset_opt['num_worker_per_gpu']
+ else: # non-distributed training
+ multiplier = 1 if num_gpu == 0 else num_gpu
+ batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
+ num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
+ dataloader_args = dict(
+ dataset=dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ num_workers=num_workers,
+ sampler=sampler,
+ drop_last=True)
+ if sampler is None:
+ dataloader_args['shuffle'] = True
+ dataloader_args['worker_init_fn'] = partial(
+ worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
+ elif phase in ['val', 'test']: # validation
+ dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
+ else:
+ raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")
+
+ dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
+ dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)
+
+ prefetch_mode = dataset_opt.get('prefetch_mode')
+ if prefetch_mode == 'cpu': # CPUPrefetcher
+ num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
+ logger = get_root_logger()
+ logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
+ return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
+ else:
+ # prefetch_mode=None: Normal dataloader
+ # prefetch_mode='cuda': dataloader for CUDAPrefetcher
+ return torch.utils.data.DataLoader(**dataloader_args)
+
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+ # Set the worker seed to num_workers * rank + worker_id + seed
+ worker_seed = num_workers * rank + worker_id + seed
+ np.random.seed(worker_seed)
+ random.seed(worker_seed)
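+
+# A minimal usage sketch (assumption, not part of the original file): any map-style
+# torch dataset can be wrapped directly; the options dict only needs the keys that
+# build_dataloader reads for the given phase:
+#     dataset = torch.utils.data.TensorDataset(torch.randn(16, 3, 8, 8))
+#     opt = {'phase': 'train', 'batch_size_per_gpu': 4, 'num_worker_per_gpu': 0}
+#     loader = build_dataloader(dataset, opt, num_gpu=1, dist=False, sampler=None, seed=0)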
diff --git a/StableSR/basicsr/data/data_sampler.py b/StableSR/basicsr/data/data_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..575452d9f844a928f7f42296c81635cfbadec7c2
--- /dev/null
+++ b/StableSR/basicsr/data/data_sampler.py
@@ -0,0 +1,48 @@
+import math
+import torch
+from torch.utils.data.sampler import Sampler
+
+
+class EnlargedSampler(Sampler):
+ """Sampler that restricts data loading to a subset of the dataset.
+
+ Modified from torch.utils.data.distributed.DistributedSampler
+    Supports enlarging the dataset for iteration-based training, saving time
+    when restarting the dataloader after each epoch.
+
+ Args:
+ dataset (torch.utils.data.Dataset): Dataset used for sampling.
+ num_replicas (int | None): Number of processes participating in
+ the training. It is usually the world_size.
+ rank (int | None): Rank of the current process within num_replicas.
+ ratio (int): Enlarging ratio. Default: 1.
+ """
+
+ def __init__(self, dataset, num_replicas, rank, ratio=1):
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.epoch = 0
+ self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
+ self.total_size = self.num_samples * self.num_replicas
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = torch.randperm(self.total_size, generator=g).tolist()
+
+ dataset_size = len(self.dataset)
+ indices = [v % dataset_size for v in indices]
+
+ # subsample
+ indices = indices[self.rank:self.total_size:self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
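+
+# A minimal usage sketch (assumption, not part of the original file): with ratio > 1
+# each "epoch" iterates the dataset multiple times, so the dataloader is restarted
+# less often during iteration-based training:
+#     sampler = EnlargedSampler(dataset, num_replicas=1, rank=0, ratio=100)
+#     sampler.set_epoch(epoch)   # call once per epoch for a different shuffle
+#     loader = build_dataloader(dataset, opt, num_gpu=1, dist=False, sampler=sampler, seed=0)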
diff --git a/StableSR/basicsr/data/data_util.py b/StableSR/basicsr/data/data_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..dce2562fb9f99475c44e9185f50018a428859214
--- /dev/null
+++ b/StableSR/basicsr/data/data_util.py
@@ -0,0 +1,362 @@
+import cv2
+import numpy as np
+import torch
+from os import path as osp
+from torch.nn import functional as F
+
+from basicsr.data.transforms import mod_crop
+from basicsr.utils import img2tensor, scandir
+
+
+def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
+ """Read a sequence of images from a given folder path.
+
+ Args:
+ path (list[str] | str): List of image paths or image folder path.
+ require_mod_crop (bool): Require mod crop for each image.
+ Default: False.
+ scale (int): Scale factor for mod_crop. Default: 1.
+        return_imgname (bool): Whether to return image names. Default: False.
+
+ Returns:
+ Tensor: size (t, c, h, w), RGB, [0, 1].
+ list[str]: Returned image name list.
+ """
+ if isinstance(path, list):
+ img_paths = path
+ else:
+ img_paths = sorted(list(scandir(path, full_path=True)))
+ imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
+
+ if require_mod_crop:
+ imgs = [mod_crop(img, scale) for img in imgs]
+ imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
+ imgs = torch.stack(imgs, dim=0)
+
+ if return_imgname:
+ imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths]
+ return imgs, imgnames
+ else:
+ return imgs
+
+
+def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
+ """Generate an index list for reading `num_frames` frames from a sequence
+ of images.
+
+ Args:
+ crt_idx (int): Current center index.
+        max_frame_num (int): Maximum frame number of the sequence of images (counting from 1).
+ num_frames (int): Reading num_frames frames.
+ padding (str): Padding mode, one of
+ 'replicate' | 'reflection' | 'reflection_circle' | 'circle'
+ Examples: current_idx = 0, num_frames = 5
+ The generated frame indices under different padding mode:
+ replicate: [0, 0, 0, 1, 2]
+ reflection: [2, 1, 0, 1, 2]
+ reflection_circle: [4, 3, 0, 1, 2]
+ circle: [3, 4, 0, 1, 2]
+
+ Returns:
+ list[int]: A list of indices.
+ """
+ assert num_frames % 2 == 1, 'num_frames should be an odd number.'
+ assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
+
+ max_frame_num = max_frame_num - 1 # start from 0
+ num_pad = num_frames // 2
+
+ indices = []
+ for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
+ if i < 0:
+ if padding == 'replicate':
+ pad_idx = 0
+ elif padding == 'reflection':
+ pad_idx = -i
+ elif padding == 'reflection_circle':
+ pad_idx = crt_idx + num_pad - i
+ else:
+ pad_idx = num_frames + i
+ elif i > max_frame_num:
+ if padding == 'replicate':
+ pad_idx = max_frame_num
+ elif padding == 'reflection':
+ pad_idx = max_frame_num * 2 - i
+ elif padding == 'reflection_circle':
+ pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
+ else:
+ pad_idx = i - num_frames
+ else:
+ pad_idx = i
+ indices.append(pad_idx)
+ return indices
+
+
+def paired_paths_from_lmdb(folders, keys):
+ """Generate paired paths from lmdb files.
+
+    Contents of lmdb. Taking `lq.lmdb` as an example, the file structure is:
+
+ ::
+
+ lq.lmdb
+ ├── data.mdb
+ ├── lock.mdb
+ ├── meta_info.txt
+
+ The data.mdb and lock.mdb are standard lmdb files and you can refer to
+ https://lmdb.readthedocs.io/en/release/ for more details.
+
+ The meta_info.txt is a specified txt file to record the meta information
+ of our datasets. It will be automatically created when preparing
+ datasets by our provided dataset tools.
+ Each line in the txt file records
+ 1)image name (with extension),
+ 2)image shape,
+ 3)compression level, separated by a white space.
+ Example: `baboon.png (120,125,3) 1`
+
+ We use the image name without extension as the lmdb key.
+ Note that we use the same key for the corresponding lq and gt images.
+
+ Args:
+        folders (list[str]): A list of folder paths. The order of the list should
+            be [input_folder, gt_folder].
+        keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+ Note that this key is different from lmdb keys.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+ f'But got {len(folders)}')
+ assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
+ input_folder, gt_folder = folders
+ input_key, gt_key = keys
+
+ if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
+        raise ValueError(f'{input_key} folder and {gt_key} folder should both be in lmdb '
+ f'formats. But received {input_key}: {input_folder}; '
+ f'{gt_key}: {gt_folder}')
+ # ensure that the two meta_info files are the same
+ with open(osp.join(input_folder, 'meta_info.txt')) as fin:
+ input_lmdb_keys = [line.split('.')[0] for line in fin]
+ with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
+ gt_lmdb_keys = [line.split('.')[0] for line in fin]
+ if set(input_lmdb_keys) != set(gt_lmdb_keys):
+ raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
+ else:
+ paths = []
+ for lmdb_key in sorted(input_lmdb_keys):
+ paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
+ return paths
+
+
+def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
+ """Generate paired paths from an meta information file.
+
+ Each line in the meta information file contains the image names and
+ image shape (usually for gt), separated by a white space.
+
+    Example of a meta information file:
+ ```
+ 0001_s001.png (480,480,3)
+ 0001_s002.png (480,480,3)
+ ```
+
+ Args:
+        folders (list[str]): A list of folder paths. The order of the list should
+            be [input_folder, gt_folder].
+        keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+ meta_info_file (str): Path to the meta information file.
+ filename_tmpl (str): Template for each filename. Note that the
+ template excludes the file extension. Usually the filename_tmpl is
+ for files in the input folder.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+ f'But got {len(folders)}')
+ assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
+ input_folder, gt_folder = folders
+ input_key, gt_key = keys
+
+ with open(meta_info_file, 'r') as fin:
+ gt_names = [line.strip().split(' ')[0] for line in fin]
+
+ paths = []
+ for gt_name in gt_names:
+ basename, ext = osp.splitext(osp.basename(gt_name))
+ input_name = f'{filename_tmpl.format(basename)}{ext}'
+ input_path = osp.join(input_folder, input_name)
+ gt_path = osp.join(gt_folder, gt_name)
+ paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
+ return paths
+
+def paired_paths_from_meta_info_file_2(folders, keys, meta_info_file, filename_tmpl):
+ """Generate paired paths from an meta information file.
+
+ Each line in the meta information file contains the image names and
+ image shape (usually for gt), separated by a white space.
+
+ Example of an meta information file:
+ ```
+ 0001_s001.png (480,480,3)
+ 0001_s002.png (480,480,3)
+ ```
+
+ Args:
+        folders (list[str]): A list of folder paths. The order of the list should
+            be [input_folder, gt_folder].
+        keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+ meta_info_file (str): Path to the meta information file.
+ filename_tmpl (str): Template for each filename. Note that the
+ template excludes the file extension. Usually the filename_tmpl is
+ for files in the input folder.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+ f'But got {len(folders)}')
+ assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
+ input_folder, gt_folder = folders
+ input_key, gt_key = keys
+
+    # read both the GT name and the LQ name from each line of the meta file
+    with open(meta_info_file, 'r') as fin:
+        lines = [line.strip().split(' ') for line in fin]
+    gt_names = [line[0] for line in lines]
+    input_names = [line[1] for line in lines]
+
+    paths = []
+    for gt_name, lq_name in zip(gt_names, input_names):
+        gt_path = osp.join(gt_folder, gt_name)
+        input_path = osp.join(input_folder, lq_name)
+        paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
+ return paths
+
+def paired_paths_from_folder(folders, keys, filename_tmpl):
+ """Generate paired paths from folders.
+
+ Args:
+        folders (list[str]): A list of folder paths. The order of the list should
+            be [input_folder, gt_folder].
+        keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+ filename_tmpl (str): Template for each filename. Note that the
+ template excludes the file extension. Usually the filename_tmpl is
+ for files in the input folder.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+ f'But got {len(folders)}')
+ assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
+ input_folder, gt_folder = folders
+ input_key, gt_key = keys
+
+ input_paths = list(scandir(input_folder))
+ gt_paths = list(scandir(gt_folder))
+ assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
+ f'{len(input_paths)}, {len(gt_paths)}.')
+ paths = []
+ for gt_path in gt_paths:
+ basename, ext = osp.splitext(osp.basename(gt_path))
+ input_name = f'{filename_tmpl.format(basename)}{ext}'
+ input_path = osp.join(input_folder, input_name)
+ assert input_name in input_paths, f'{input_name} is not in {input_key}_paths.'
+ gt_path = osp.join(gt_folder, gt_path)
+ paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
+ return paths
+
+
+def paths_from_folder(folder):
+ """Generate paths from folder.
+
+ Args:
+ folder (str): Folder path.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+
+ paths = list(scandir(folder))
+ paths = [osp.join(folder, path) for path in paths]
+ return paths
+
+
+def paths_from_lmdb(folder):
+ """Generate paths from lmdb.
+
+ Args:
+ folder (str): Folder path.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+ if not folder.endswith('.lmdb'):
+        raise ValueError(f'Folder {folder} should be in lmdb format.')
+ with open(osp.join(folder, 'meta_info.txt')) as fin:
+ paths = [line.split('.')[0] for line in fin]
+ return paths
+
+
+def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
+ """Generate Gaussian kernel used in `duf_downsample`.
+
+ Args:
+ kernel_size (int): Kernel size. Default: 13.
+ sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
+
+ Returns:
+ np.array: The Gaussian kernel.
+ """
+ from scipy.ndimage import filters as filters
+ kernel = np.zeros((kernel_size, kernel_size))
+ # set element at the middle to one, a dirac delta
+ kernel[kernel_size // 2, kernel_size // 2] = 1
+ # gaussian-smooth the dirac, resulting in a gaussian filter
+ return filters.gaussian_filter(kernel, sigma)
+
+
+def duf_downsample(x, kernel_size=13, scale=4):
+ """Downsamping with Gaussian kernel used in the DUF official code.
+
+ Args:
+ x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
+ kernel_size (int): Kernel size. Default: 13.
+ scale (int): Downsampling factor. Supported scale: (2, 3, 4).
+ Default: 4.
+
+ Returns:
+ Tensor: DUF downsampled frames.
+ """
+ assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
+
+ squeeze_flag = False
+ if x.ndim == 4:
+ squeeze_flag = True
+ x = x.unsqueeze(0)
+ b, t, c, h, w = x.size()
+ x = x.view(-1, 1, h, w)
+ pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
+ x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
+
+ gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
+ gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
+ x = F.conv2d(x, gaussian_filter, stride=scale)
+ x = x[:, :, 2:-2, 2:-2]
+ x = x.view(b, t, c, x.size(2), x.size(3))
+ if squeeze_flag:
+ x = x.squeeze(0)
+ return x
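+
+# A minimal usage sketch (assumption, not part of the original file): the Gaussian
+# blur-and-stride downsampling reduces each spatial dimension by `scale`:
+#     x = torch.rand(2, 3, 64, 64)          # (t, c, h, w); a batch dim is added internally
+#     y = duf_downsample(x, kernel_size=13, scale=4)
+#     # expected shape: (2, 3, 16, 16)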
diff --git a/StableSR/basicsr/data/degradations.py b/StableSR/basicsr/data/degradations.py
new file mode 100644
index 0000000000000000000000000000000000000000..5db40fb080908e9a0de503b9c9518710f89e2e0d
--- /dev/null
+++ b/StableSR/basicsr/data/degradations.py
@@ -0,0 +1,935 @@
+import cv2
+import math
+import numpy as np
+import random
+import torch
+from scipy import special
+from scipy.stats import multivariate_normal
+from torchvision.transforms.functional_tensor import rgb_to_grayscale
+
+# -------------------------------------------------------------------- #
+# --------------------------- blur kernels --------------------------- #
+# -------------------------------------------------------------------- #
+
+
+# --------------------------- util functions --------------------------- #
+def sigma_matrix2(sig_x, sig_y, theta):
+ """Calculate the rotated sigma matrix (two dimensional matrix).
+
+ Args:
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+
+ Returns:
+ ndarray: Rotated sigma matrix.
+ """
+ d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
+ u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
+ return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
+
+
+def mesh_grid(kernel_size):
+ """Generate the mesh grid, centering at zero.
+
+ Args:
+ kernel_size (int):
+
+ Returns:
+ xy (ndarray): with the shape (kernel_size, kernel_size, 2)
+ xx (ndarray): with the shape (kernel_size, kernel_size)
+ yy (ndarray): with the shape (kernel_size, kernel_size)
+ """
+ ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
+ xx, yy = np.meshgrid(ax, ax)
+ xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size,
+ 1))).reshape(kernel_size, kernel_size, 2)
+ return xy, xx, yy
+
+
+def pdf2(sigma_matrix, grid):
+ """Calculate PDF of the bivariate Gaussian distribution.
+
+ Args:
+ sigma_matrix (ndarray): with the shape (2, 2)
+ grid (ndarray): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size.
+
+ Returns:
+        kernel (ndarray): un-normalized kernel.
+ """
+ inverse_sigma = np.linalg.inv(sigma_matrix)
+ kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
+ return kernel
+
+
+def cdf2(d_matrix, grid):
+ """Calculate the CDF of the standard bivariate Gaussian distribution.
+ Used in skewed Gaussian distribution.
+
+ Args:
+        d_matrix (ndarray): skew matrix.
+ grid (ndarray): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size.
+
+ Returns:
+ cdf (ndarray): skewed cdf.
+ """
+ rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
+ grid = np.dot(grid, d_matrix)
+ cdf = rv.cdf(grid)
+ return cdf
+
+
+def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
+ """Generate a bivariate isotropic or anisotropic Gaussian kernel.
+
+    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.
+
+ Args:
+ kernel_size (int):
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size. Default: None
+ isotropic (bool):
+
+ Returns:
+ kernel (ndarray): normalized kernel.
+ """
+ if grid is None:
+ grid, _, _ = mesh_grid(kernel_size)
+ if isotropic:
+ sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
+ else:
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+ kernel = pdf2(sigma_matrix, grid)
+ kernel = kernel / np.sum(kernel)
+ return kernel
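+
+# A minimal usage sketch (assumption, not part of the original file): the returned
+# kernel is normalized to sum to 1 and can be applied with cv2.filter2D, e.g.:
+#     kernel = bivariate_Gaussian(kernel_size=21, sig_x=2.0, sig_y=4.0, theta=0.5, isotropic=False)
+#     # kernel.shape == (21, 21); blurred = cv2.filter2D(img, -1, kernel)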
+
+
+def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
+ """Generate a bivariate generalized Gaussian kernel.
+
+ ``Paper: Parameter Estimation For Multivariate Generalized Gaussian Distributions``
+
+    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.
+
+ Args:
+ kernel_size (int):
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+ beta (float): shape parameter, beta = 1 is the normal distribution.
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size. Default: None
+
+ Returns:
+ kernel (ndarray): normalized kernel.
+ """
+ if grid is None:
+ grid, _, _ = mesh_grid(kernel_size)
+ if isotropic:
+ sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
+ else:
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+ inverse_sigma = np.linalg.inv(sigma_matrix)
+ kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
+ """Generate a plateau-like anisotropic kernel.
+
+ 1 / (1+x^(beta))
+
+ Reference: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution
+
+    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.
+
+ Args:
+ kernel_size (int):
+ sig_x (float):
+ sig_y (float):
+ theta (float): Radian measurement.
+ beta (float): shape parameter, beta = 1 is the normal distribution.
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
+ with the shape (K, K, 2), K is the kernel size. Default: None
+
+ Returns:
+ kernel (ndarray): normalized kernel.
+ """
+ if grid is None:
+ grid, _, _ = mesh_grid(kernel_size)
+ if isotropic:
+ sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
+ else:
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+ inverse_sigma = np.linalg.inv(sigma_matrix)
+ kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
+ kernel = kernel / np.sum(kernel)
+ return kernel
+
+
+def random_bivariate_Gaussian(kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ noise_range=None,
+ isotropic=True,
+ return_sigma=False):
+ """Randomly generate bivariate isotropic or anisotropic Gaussian kernels.
+
+    In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.
+
+ Args:
+ kernel_size (int):
+ sigma_x_range (tuple): [0.6, 5]
+ sigma_y_range (tuple): [0.6, 5]
+        rotation_range (tuple): [-math.pi, math.pi]
+ noise_range(tuple, optional): multiplicative kernel noise,
+ [0.75, 1.25]. Default: None
+
+ Returns:
+ kernel (ndarray):
+ """
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+ if isotropic is False:
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+ else:
+ sigma_y = sigma_x
+ rotation = 0
+
+ kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)
+
+ # add multiplicative noise
+ if noise_range is not None:
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+ noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
+ kernel = kernel * noise
+ kernel = kernel / np.sum(kernel)
+ if not return_sigma:
+ return kernel
+ else:
+ return kernel, [sigma_x, sigma_y]
+
+
+def random_bivariate_generalized_Gaussian(kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ beta_range,
+ noise_range=None,
+ isotropic=True,
+ return_sigma=False):
+ """Randomly generate bivariate generalized Gaussian kernels.
+
+    In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.
+
+ Args:
+ kernel_size (int):
+ sigma_x_range (tuple): [0.6, 5]
+ sigma_y_range (tuple): [0.6, 5]
+        rotation_range (tuple): [-math.pi, math.pi]
+ beta_range (tuple): [0.5, 8]
+ noise_range(tuple, optional): multiplicative kernel noise,
+ [0.75, 1.25]. Default: None
+
+ Returns:
+ kernel (ndarray):
+ """
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+ if isotropic is False:
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+ else:
+ sigma_y = sigma_x
+ rotation = 0
+
+ # assume beta_range[0] < 1 < beta_range[1]
+ if np.random.uniform() < 0.5:
+ beta = np.random.uniform(beta_range[0], 1)
+ else:
+ beta = np.random.uniform(1, beta_range[1])
+
+ kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
+
+ # add multiplicative noise
+ if noise_range is not None:
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+ noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
+ kernel = kernel * noise
+ kernel = kernel / np.sum(kernel)
+ if not return_sigma:
+ return kernel
+ else:
+ return kernel, [sigma_x, sigma_y]
+
+
+def random_bivariate_plateau(kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ beta_range,
+ noise_range=None,
+ isotropic=True,
+ return_sigma=False):
+ """Randomly generate bivariate plateau kernels.
+
+    In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.
+
+ Args:
+ kernel_size (int):
+ sigma_x_range (tuple): [0.6, 5]
+ sigma_y_range (tuple): [0.6, 5]
+        rotation_range (tuple): [-math.pi/2, math.pi/2]
+ beta_range (tuple): [1, 4]
+ noise_range(tuple, optional): multiplicative kernel noise,
+ [0.75, 1.25]. Default: None
+
+ Returns:
+ kernel (ndarray):
+ """
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
+ if isotropic is False:
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
+ else:
+ sigma_y = sigma_x
+ rotation = 0
+
+ # TODO: this may be not proper
+ if np.random.uniform() < 0.5:
+ beta = np.random.uniform(beta_range[0], 1)
+ else:
+ beta = np.random.uniform(1, beta_range[1])
+
+ kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
+ # add multiplicative noise
+ if noise_range is not None:
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
+ noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
+ kernel = kernel * noise
+ kernel = kernel / np.sum(kernel)
+
+ if not return_sigma:
+ return kernel
+ else:
+ return kernel, [sigma_x, sigma_y]
+
+
+def random_mixed_kernels(kernel_list,
+ kernel_prob,
+ kernel_size=21,
+ sigma_x_range=(0.6, 5),
+ sigma_y_range=(0.6, 5),
+ rotation_range=(-math.pi, math.pi),
+ betag_range=(0.5, 8),
+ betap_range=(0.5, 8),
+ noise_range=None,
+ return_sigma=False):
+ """Randomly generate mixed kernels.
+
+ Args:
+        kernel_list (tuple): a list of kernel type names; supported types are
+            ['iso', 'aniso', 'generalized_iso', 'generalized_aniso',
+            'plateau_iso', 'plateau_aniso']
+ kernel_prob (tuple): corresponding kernel probability for each
+ kernel type
+ kernel_size (int):
+ sigma_x_range (tuple): [0.6, 5]
+ sigma_y_range (tuple): [0.6, 5]
+        rotation_range (tuple): [-math.pi, math.pi]
+        betag_range (tuple): [0.5, 8]
+        betap_range (tuple): [0.5, 8]
+ noise_range(tuple, optional): multiplicative kernel noise,
+ [0.75, 1.25]. Default: None
+
+ Returns:
+ kernel (ndarray):
+ """
+ kernel_type = random.choices(kernel_list, kernel_prob)[0]
+ if not return_sigma:
+ if kernel_type == 'iso':
+ kernel = random_bivariate_Gaussian(
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True, return_sigma=return_sigma)
+ elif kernel_type == 'aniso':
+ kernel = random_bivariate_Gaussian(
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False, return_sigma=return_sigma)
+ elif kernel_type == 'generalized_iso':
+ kernel = random_bivariate_generalized_Gaussian(
+ kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ betag_range,
+ noise_range=noise_range,
+ isotropic=True,
+ return_sigma=return_sigma)
+ elif kernel_type == 'generalized_aniso':
+ kernel = random_bivariate_generalized_Gaussian(
+ kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ betag_range,
+ noise_range=noise_range,
+ isotropic=False,
+ return_sigma=return_sigma)
+ elif kernel_type == 'plateau_iso':
+ kernel = random_bivariate_plateau(
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True, return_sigma=return_sigma)
+ elif kernel_type == 'plateau_aniso':
+ kernel = random_bivariate_plateau(
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False, return_sigma=return_sigma)
+ return kernel
+ else:
+ if kernel_type == 'iso':
+ kernel, sigma_list = random_bivariate_Gaussian(
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True, return_sigma=return_sigma)
+ elif kernel_type == 'aniso':
+ kernel, sigma_list = random_bivariate_Gaussian(
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False, return_sigma=return_sigma)
+ elif kernel_type == 'generalized_iso':
+ kernel, sigma_list = random_bivariate_generalized_Gaussian(
+ kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ betag_range,
+ noise_range=noise_range,
+ isotropic=True,
+ return_sigma=return_sigma)
+ elif kernel_type == 'generalized_aniso':
+ kernel, sigma_list = random_bivariate_generalized_Gaussian(
+ kernel_size,
+ sigma_x_range,
+ sigma_y_range,
+ rotation_range,
+ betag_range,
+ noise_range=noise_range,
+ isotropic=False,
+ return_sigma=return_sigma)
+ elif kernel_type == 'plateau_iso':
+ kernel, sigma_list = random_bivariate_plateau(
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True, return_sigma=return_sigma)
+ elif kernel_type == 'plateau_aniso':
+ kernel, sigma_list = random_bivariate_plateau(
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False, return_sigma=return_sigma)
+ return kernel, sigma_list
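+
+# A minimal usage sketch (assumption, not part of the original file): sample one
+# kernel type according to the given probabilities, then draw its parameters from
+# the provided ranges:
+#     kernel = random_mixed_kernels(
+#         ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
+#         [0.45, 0.25, 0.12, 0.03, 0.12, 0.03], kernel_size=21)
+#     # kernel.shape == (21, 21)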
+
+
+np.seterr(divide='ignore', invalid='ignore')
+
+
+def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
+ """2D sinc filter
+
+ Reference: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter
+
+ Args:
+ cutoff (float): cutoff frequency in radians (pi is max)
+ kernel_size (int): horizontal and vertical size, must be odd.
+ pad_to (int): pad kernel size to desired size, must be odd or zero.
+ """
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
+ kernel = np.fromfunction(
+ lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
+ (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt(
+ (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size])
+ kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi)
+ kernel = kernel / np.sum(kernel)
+ if pad_to > kernel_size:
+ pad_size = (pad_to - kernel_size) // 2
+ kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
+ return kernel
+
+
+# ------------------------------------------------------------- #
+# --------------------------- noise --------------------------- #
+# ------------------------------------------------------------- #
+
+# ----------------------- Gaussian Noise ----------------------- #
+
+
+def generate_gaussian_noise(img, sigma=10, gray_noise=False):
+ """Generate Gaussian noise.
+
+ Args:
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
+ sigma (float): Noise scale (measured in range 255). Default: 10.
+
+ Returns:
+        (Numpy array): Generated noise, shape (h, w, c), float32.
+ """
+ if gray_noise:
+ noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.
+ noise = np.expand_dims(noise, axis=2).repeat(3, axis=2)
+ else:
+ noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255.
+ return noise
+
+
+def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False):
+ """Add Gaussian noise.
+
+ Args:
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
+ sigma (float): Noise scale (measured in range 255). Default: 10.
+
+ Returns:
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
+ float32.
+ """
+ noise = generate_gaussian_noise(img, sigma, gray_noise)
+ out = img + noise
+ if clip and rounds:
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
+ elif clip:
+ out = np.clip(out, 0, 1)
+ elif rounds:
+ out = (out * 255.0).round() / 255.
+ return out
+
+
+def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
+ """Add Gaussian noise (PyTorch version).
+
+ Args:
+ img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
+ scale (float | Tensor): Noise scale. Default: 1.0.
+
+ Returns:
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
+ float32.
+ """
+ b, _, h, w = img.size()
+ if not isinstance(sigma, (float, int)):
+ sigma = sigma.view(img.size(0), 1, 1, 1)
+ if isinstance(gray_noise, (float, int)):
+ cal_gray_noise = gray_noise > 0
+ else:
+ gray_noise = gray_noise.view(b, 1, 1, 1)
+ cal_gray_noise = torch.sum(gray_noise) > 0
+
+ if cal_gray_noise:
+ noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255.
+ noise_gray = noise_gray.view(b, 1, h, w)
+
+ # always calculate color noise
+ noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.
+
+ if cal_gray_noise:
+ noise = noise * (1 - gray_noise) + noise_gray * gray_noise
+ return noise
+
+
+def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
+ """Add Gaussian noise (PyTorch version).
+
+ Args:
+ img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
+ sigma (float | Tensor): Noise scale (measured in range 255). Number or
+ Tensor with shape (b). Default: 10.
+ gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
+ 0 for False, 1 for True. Default: 0.
+
+ Returns:
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
+ float32.
+ """
+ noise = generate_gaussian_noise_pt(img, sigma, gray_noise)
+ out = img + noise
+ if clip and rounds:
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+ elif clip:
+ out = torch.clamp(out, 0, 1)
+ elif rounds:
+ out = (out * 255.0).round() / 255.
+ return out
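+
+# Example (illustrative sketch): adding Gaussian noise with a per-sample sigma
+# to a batch of tensors in [0, 1].
+#
+#   imgs = torch.rand(4, 3, 64, 64)             # (b, c, h, w)
+#   sigmas = torch.tensor([5., 10., 15., 20.])  # one sigma (255 range) per sample
+#   noisy = add_gaussian_noise_pt(imgs, sigma=sigmas, gray_noise=0, clip=True)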
+
+
+# ----------------------- Random Gaussian Noise ----------------------- #
+def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0, return_sigma=False):
+ sigma = np.random.uniform(sigma_range[0], sigma_range[1])
+ if np.random.uniform() < gray_prob:
+ gray_noise = True
+ else:
+ gray_noise = False
+ if return_sigma:
+ return generate_gaussian_noise(img, sigma, gray_noise), sigma
+ else:
+ return generate_gaussian_noise(img, sigma, gray_noise)
+
+
+def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False, return_sigma=False):
+ if return_sigma:
+ noise, sigma = random_generate_gaussian_noise(img, sigma_range, gray_prob, return_sigma=return_sigma)
+ else:
+ noise = random_generate_gaussian_noise(img, sigma_range, gray_prob, return_sigma=return_sigma)
+ out = img + noise
+ if clip and rounds:
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
+ elif clip:
+ out = np.clip(out, 0, 1)
+ elif rounds:
+ out = (out * 255.0).round() / 255.
+ if return_sigma:
+ return out, sigma
+ else:
+ return out
+
+
+def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
+ sigma = torch.rand(
+ img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
+ gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
+ gray_noise = (gray_noise < gray_prob).float()
+ return generate_gaussian_noise_pt(img, sigma, gray_noise)
+
+
+def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
+ noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
+ out = img + noise
+ if clip and rounds:
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+ elif clip:
+ out = torch.clamp(out, 0, 1)
+ elif rounds:
+ out = (out * 255.0).round() / 255.
+ return out
+
+# ----------------------- Poisson (Shot) Noise ----------------------- #
+
+
+def generate_poisson_noise(img, scale=1.0, gray_noise=False):
+ """Generate poisson noise.
+
+ Reference: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219
+
+ Args:
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
+ scale (float): Noise scale. Default: 1.0.
+ gray_noise (bool): Whether to generate gray noise. Default: False.
+
+ Returns:
+ (Numpy array): Generated Poisson noise (already multiplied by scale),
+ shape (h, w, c), float32.
+ """
+ if gray_noise:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ # round and clip image for counting vals correctly
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+ vals = len(np.unique(img))
+ vals = 2**np.ceil(np.log2(vals))
+ out = np.float32(np.random.poisson(img * vals) / float(vals))
+ noise = out - img
+ if gray_noise:
+ noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2)
+ return noise * scale
+
+
+def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False):
+ """Add poisson noise.
+
+ Args:
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
+ scale (float): Noise scale. Default: 1.0.
+ gray_noise (bool): Whether to generate gray noise. Default: False.
+
+ Returns:
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
+ float32.
+ """
+ noise = generate_poisson_noise(img, scale, gray_noise)
+ out = img + noise
+ if clip and rounds:
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
+ elif clip:
+ out = np.clip(out, 0, 1)
+ elif rounds:
+ out = (out * 255.0).round() / 255.
+ return out
+
+
+def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
+ """Generate a batch of poisson noise (PyTorch version)
+
+ Args:
+ img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
+ scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
+ Default: 1.0.
+ gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
+ 0 for False, 1 for True. Default: 0.
+
+ Returns:
+ (Tensor): Generated Poisson noise (already multiplied by scale),
+ shape (b, c, h, w), float32.
+ """
+ b, _, h, w = img.size()
+ if isinstance(gray_noise, (float, int)):
+ cal_gray_noise = gray_noise > 0
+ else:
+ gray_noise = gray_noise.view(b, 1, 1, 1)
+ cal_gray_noise = torch.sum(gray_noise) > 0
+ if cal_gray_noise:
+ img_gray = rgb_to_grayscale(img, num_output_channels=1)
+ # round and clip image for counting vals correctly
+ img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.
+ # use for-loop to get the unique values for each sample
+ vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
+ vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
+ vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)
+ out = torch.poisson(img_gray * vals) / vals
+ noise_gray = out - img_gray
+ noise_gray = noise_gray.expand(b, 3, h, w)
+
+ # always calculate color noise
+ # round and clip image for counting vals correctly
+ img = torch.clamp((img * 255.0).round(), 0, 255) / 255.
+ # use for-loop to get the unique values for each sample
+ vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
+ vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
+ vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
+ out = torch.poisson(img * vals) / vals
+ noise = out - img
+ if cal_gray_noise:
+ noise = noise * (1 - gray_noise) + noise_gray * gray_noise
+ if not isinstance(scale, (float, int)):
+ scale = scale.view(b, 1, 1, 1)
+ return noise * scale
+
+
+def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
+ """Add poisson noise to a batch of images (PyTorch version).
+
+ Args:
+ img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
+ scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
+ Default: 1.0.
+ gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
+ 0 for False, 1 for True. Default: 0.
+
+ Returns:
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
+ float32.
+ """
+ noise = generate_poisson_noise_pt(img, scale, gray_noise)
+ out = img + noise
+ if clip and rounds:
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+ elif clip:
+ out = torch.clamp(out, 0, 1)
+ elif rounds:
+ out = (out * 255.0).round() / 255.
+ return out
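+
+# Example (illustrative sketch): Poisson (shot) noise strength is controlled by
+# scale, which may also be a per-sample tensor of shape (b,).
+#
+#   imgs = torch.rand(2, 3, 32, 32).clamp(0, 1)
+#   noisy = add_poisson_noise_pt(imgs, scale=torch.tensor([0.5, 2.0]), gray_noise=0)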
+
+
+# ----------------------- Random Poisson (Shot) Noise ----------------------- #
+
+
+def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
+ scale = np.random.uniform(scale_range[0], scale_range[1])
+ if np.random.uniform() < gray_prob:
+ gray_noise = True
+ else:
+ gray_noise = False
+ return generate_poisson_noise(img, scale, gray_noise)
+
+
+def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
+ noise = random_generate_poisson_noise(img, scale_range, gray_prob)
+ out = img + noise
+ if clip and rounds:
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
+ elif clip:
+ out = np.clip(out, 0, 1)
+ elif rounds:
+ out = (out * 255.0).round() / 255.
+ return out
+
+
+def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
+ scale = torch.rand(
+ img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]
+ gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
+ gray_noise = (gray_noise < gray_prob).float()
+ return generate_poisson_noise_pt(img, scale, gray_noise)
+
+
+def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
+ noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob)
+ out = img + noise
+ if clip and rounds:
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+ elif clip:
+ out = torch.clamp(out, 0, 1)
+ elif rounds:
+ out = (out * 255.0).round() / 255.
+ return out
+
+# ----------------------- Random speckle Noise ----------------------- #
+
+def random_add_speckle_noise(imgs, speckle_std):
+ """Add multiplicative (speckle) noise: noisy = img + N(0, std) * img."""
+ std_l, std_r = speckle_std
+ mean = 0
+ std = random.uniform(std_l / 255., std_r / 255.)
+
+ outputs = []
+ for img in imgs:
+ gauss = np.random.normal(loc=mean, scale=std, size=img.shape)
+ noisy = img + gauss * img
+ noisy = np.clip(noisy, 0, 1).astype(np.float32)
+ outputs.append(noisy)
+
+ return outputs
+
+
+def random_add_speckle_noise_pt(img, speckle_std):
+ """Add multiplicative (speckle) noise (PyTorch version)."""
+ std_l, std_r = speckle_std
+ mean = 0
+ std = random.uniform(std_l / 255., std_r / 255.)
+ gauss = torch.normal(mean=mean, std=std, size=img.size()).to(img.device)
+ noisy = img + gauss * img
+ noisy = torch.clamp(noisy, 0, 1)
+ return noisy
+
+# ----------------------- Random saltpepper Noise ----------------------- #
+
+def random_add_saltpepper_noise(imgs, saltpepper_amount, saltpepper_svsp):
+ p_range = saltpepper_amount
+ p = random.uniform(p_range[0], p_range[1])
+ q_range = saltpepper_svsp
+ q = random.uniform(q_range[0], q_range[1])
+
+ outputs = []
+ for img in imgs:
+ out = img.copy()
+ flipped = np.random.choice([True, False], size=img.shape,
+ p=[p, 1 - p])
+ salted = np.random.choice([True, False], size=img.shape,
+ p=[q, 1 - q])
+ peppered = ~salted
+ out[flipped & salted] = 1
+ out[flipped & peppered] = 0.
+ noisy = np.clip(out, 0, 1).astype(np.float32)
+
+ outputs.append(noisy)
+
+ return outputs
+
+def random_add_saltpepper_noise_pt(imgs, saltpepper_amount, saltpepper_svsp):
+ """Add salt-and-pepper noise to a batch of images (PyTorch version)."""
+ p_range = saltpepper_amount
+ p = random.uniform(p_range[0], p_range[1])
+ q_range = saltpepper_svsp
+ q = random.uniform(q_range[0], q_range[1])
+
+ imgs = imgs.permute(0, 2, 3, 1)
+
+ outputs = []
+ for i in range(imgs.size(0)):
+ img = imgs[i]
+ out = img.clone()
+ flipped = np.random.choice([True, False], size=img.shape, p=[p, 1 - p])
+ salted = np.random.choice([True, False], size=img.shape, p=[q, 1 - q])
+ peppered = ~salted
+ out[flipped & salted] = 1
+ out[flipped & peppered] = 0.
+ noisy = torch.clamp(out, 0, 1)
+ outputs.append(noisy.permute(2, 0, 1))
+ # stack along a new batch dimension to keep the (b, c, h, w) layout
+ if len(outputs) > 1:
+ return torch.stack(outputs, dim=0)
+ else:
+ return outputs[0].unsqueeze(0)
+
+# ----------------------- Random screen Noise ----------------------- #
+
+def random_add_screen_noise(imgs, linewidth, space):
+ linewidth = int(np.random.uniform(linewidth[0], linewidth[1]))
+ space = int(np.random.uniform(space[0], space[1]))
+ center_color = [213, 230, 230] # RGB
+ outputs = []
+ for img in imgs:
+ noise = img.copy()
+
+ # note: the stripe mask below assumes square inputs (h == w)
+ tmp_mask = np.zeros((img.shape[1], img.shape[0]), dtype=np.float32)
+ for i in range(0, img.shape[0], int((space+linewidth))):
+ tmp_mask[:, i:(i+linewidth)] = 1
+ colour_masks = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.float32)
+ colour_masks[:,:,0] = (center_color[0] + np.random.uniform(-20, 20))/255.
+ colour_masks[:,:,1] = (center_color[1] + np.random.uniform(0, 20))/255.
+ colour_masks[:,:,2] = (center_color[2] + np.random.uniform(0, 20))/255.
+ noise_color = cv2.addWeighted(noise, 0.6, colour_masks, 0.4, 0.0)
+ noise = noise*(1-(tmp_mask[:,:,np.newaxis])) + noise_color*(tmp_mask[:,:,np.newaxis])
+
+ outputs.append(noise)
+
+ return outputs
+
+
+# ------------------------------------------------------------------------ #
+# --------------------------- JPEG compression --------------------------- #
+# ------------------------------------------------------------------------ #
+
+
+def add_jpg_compression(img, quality=90):
+ """Add JPG compression artifacts.
+
+ Args:
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
+ quality (float): JPG compression quality. 0 for lowest quality, 100 for
+ best quality. Default: 90.
+
+ Returns:
+ (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
+ float32.
+ """
+ img = np.clip(img, 0, 1)
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(quality)]
+ _, encimg = cv2.imencode('.jpg', img * 255., encode_param)
+ img = np.float32(cv2.imdecode(encimg, 1)) / 255.
+ return img
+
+
+def random_add_jpg_compression(img, quality_range=(90, 100), return_q=False):
+ """Randomly add JPG compression artifacts.
+
+ Args:
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
+ quality_range (tuple[float] | list[float]): JPG compression quality
+ range. 0 for lowest quality, 100 for best quality.
+ Default: (90, 100).
+
+ Returns:
+ (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
+ float32.
+ """
+ quality = np.random.uniform(quality_range[0], quality_range[1])
+ if return_q:
+ return add_jpg_compression(img, quality), quality
+ else:
+ return add_jpg_compression(img, quality)
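+
+
+# A rough end-to-end sketch of how these primitives are typically chained for a
+# classical blur -> downsample -> noise -> JPEG degradation. Illustrative only;
+# the actual training pipelines build kernels in the datasets and apply the
+# degradations elsewhere (e.g. on the GPU).
+#
+#   kernel = random_mixed_kernels(
+#       ['iso', 'aniso'], [0.5, 0.5], 21, (0.2, 3), (0.2, 3), [-math.pi, math.pi],
+#       noise_range=None)
+#   lq = cv2.filter2D(hq, -1, kernel)  # hq: float32 HWC image in [0, 1]
+#   lq = cv2.resize(lq, None, fx=0.25, fy=0.25, interpolation=cv2.INTER_LINEAR)
+#   lq = random_add_gaussian_noise(lq, sigma_range=(1, 30))
+#   lq = random_add_jpg_compression(lq, quality_range=(30, 95))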
diff --git a/StableSR/basicsr/data/ffhq_dataset.py b/StableSR/basicsr/data/ffhq_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..23992eb877f6b7b46cf5f40ed3667fc10916269b
--- /dev/null
+++ b/StableSR/basicsr/data/ffhq_dataset.py
@@ -0,0 +1,80 @@
+import random
+import time
+from os import path as osp
+from torch.utils import data as data
+from torchvision.transforms.functional import normalize
+
+from basicsr.data.transforms import augment
+from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from basicsr.utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class FFHQDataset(data.Dataset):
+ """FFHQ dataset for StyleGAN.
+
+ Args:
+ opt (dict): Config for train datasets. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ io_backend (dict): IO backend type and other kwarg.
+ mean (list | tuple): Image mean.
+ std (list | tuple): Image std.
+ use_hflip (bool): Whether to horizontally flip.
+
+ """
+
+ def __init__(self, opt):
+ super(FFHQDataset, self).__init__()
+ self.opt = opt
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+
+ self.gt_folder = opt['dataroot_gt']
+ self.mean = opt['mean']
+ self.std = opt['std']
+
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.io_backend_opt['db_paths'] = self.gt_folder
+ if not self.gt_folder.endswith('.lmdb'):
+ raise ValueError("'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
+ with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
+ self.paths = [line.split('.')[0] for line in fin]
+ else:
+ # FFHQ has 70000 images in total
+ self.paths = [osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)]
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ # load gt image
+ gt_path = self.paths[index]
+ # avoid errors caused by high latency in reading files
+ retry = 3
+ while retry > 0:
+ try:
+ img_bytes = self.file_client.get(gt_path)
+ except Exception as e:
+ logger = get_root_logger()
+ logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
+ # change another file to read
+ index = random.randint(0, self.__len__() - 1)
+ gt_path = self.paths[index]
+ time.sleep(1) # sleep 1s for occasional server congestion
+ else:
+ break
+ finally:
+ retry -= 1
+ img_gt = imfrombytes(img_bytes, float32=True)
+
+ # random horizontal flip
+ img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
+ # BGR to RGB, HWC to CHW, numpy to tensor
+ img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
+ # normalize
+ normalize(img_gt, self.mean, self.std, inplace=True)
+ return {'gt': img_gt, 'gt_path': gt_path}
+
+ def __len__(self):
+ return len(self.paths)
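+
+
+# A minimal (hypothetical) option sketch for this dataset; the paths and values
+# below are placeholders, not a tested configuration.
+#
+#   opt = dict(
+#       dataroot_gt='datasets/ffhq',
+#       io_backend=dict(type='disk'),
+#       mean=[0.5, 0.5, 0.5],
+#       std=[0.5, 0.5, 0.5],
+#       use_hflip=True)
+#   dataset = FFHQDataset(opt)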
diff --git a/StableSR/basicsr/data/ffhq_degradation_dataset.py b/StableSR/basicsr/data/ffhq_degradation_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..07ddbc70cb9c0edc14880e78969273502ba27a4d
--- /dev/null
+++ b/StableSR/basicsr/data/ffhq_degradation_dataset.py
@@ -0,0 +1,231 @@
+import cv2
+import math
+import numpy as np
+import os.path as osp
+import torch
+import torch.utils.data as data
+import random
+from basicsr.data import degradations as degradations
+from basicsr.data.data_util import paths_from_folder
+from basicsr.data.transforms import augment
+from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from basicsr.utils.registry import DATASET_REGISTRY
+from pathlib import Path
+from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
+ normalize)
+
+@DATASET_REGISTRY.register()
+class FFHQDegradationDataset(data.Dataset):
+ """FFHQ dataset for GFPGAN.
+ It reads high-resolution images and then generates low-quality (LQ) images on-the-fly.
+ Args:
+ opt (dict): Config for train datasets. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ io_backend (dict): IO backend type and other kwarg.
+ mean (list | tuple): Image mean.
+ std (list | tuple): Image std.
+ use_hflip (bool): Whether to horizontally flip.
+ Please see more options in the code.
+ """
+
+ def __init__(self, opt):
+ super(FFHQDegradationDataset, self).__init__()
+ self.opt = opt
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ if 'image_type' not in opt:
+ opt['image_type'] = 'png'
+
+ self.gt_folder = opt['dataroot_gt']
+ self.mean = opt['mean']
+ self.std = opt['std']
+ self.out_size = opt['out_size']
+
+ self.crop_components = opt.get('crop_components', False) # facial components
+ self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1) # ratio for enlarging eye regions
+
+ if self.crop_components:
+ # load component list from a pre-processed pth file
+ self.components_list = torch.load(opt.get('component_path'))
+
+ # file client (lmdb io backend)
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.io_backend_opt['db_paths'] = self.gt_folder
+ if not self.gt_folder.endswith('.lmdb'):
+ raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
+ with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
+ self.paths = [line.split('.')[0] for line in fin]
+ else:
+ # disk backend: scan file list from a folder
+ self.paths = sorted([str(x) for x in Path(self.gt_folder).glob('*.' + opt['image_type'])])
+
+ # degradation configurations
+ self.blur_kernel_size = opt['blur_kernel_size']
+ self.kernel_list = opt['kernel_list']
+ self.kernel_prob = opt['kernel_prob']
+ self.blur_sigma = opt['blur_sigma']
+ self.downsample_range = opt['downsample_range']
+ self.noise_range = opt['noise_range']
+ self.jpeg_range = opt['jpeg_range']
+
+ # color jitter
+ self.color_jitter_prob = opt.get('color_jitter_prob')
+ self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob')
+ self.color_jitter_shift = opt.get('color_jitter_shift', 20)
+ # to gray
+ self.gray_prob = opt.get('gray_prob')
+
+ logger = get_root_logger()
+ logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
+ logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
+ logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
+ logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')
+
+ if self.color_jitter_prob is not None:
+ logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')
+ if self.gray_prob is not None:
+ logger.info(f'Use random gray. Prob: {self.gray_prob}')
+ self.color_jitter_shift /= 255.
+
+ @staticmethod
+ def color_jitter(img, shift):
+ """jitter color: randomly jitter the RGB values, in numpy formats"""
+ jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
+ img = img + jitter_val
+ img = np.clip(img, 0, 1)
+ return img
+
+ @staticmethod
+ def color_jitter_pt(img, brightness, contrast, saturation, hue):
+ """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
+ fn_idx = torch.randperm(4)
+ for fn_id in fn_idx:
+ if fn_id == 0 and brightness is not None:
+ brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
+ img = adjust_brightness(img, brightness_factor)
+
+ if fn_id == 1 and contrast is not None:
+ contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
+ img = adjust_contrast(img, contrast_factor)
+
+ if fn_id == 2 and saturation is not None:
+ saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
+ img = adjust_saturation(img, saturation_factor)
+
+ if fn_id == 3 and hue is not None:
+ hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
+ img = adjust_hue(img, hue_factor)
+ return img
+
+ def get_component_coordinates(self, index, status):
+ """Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file"""
+ components_bbox = self.components_list[f'{index:08d}']
+ if status[0]: # hflip
+ # exchange right and left eye
+ tmp = components_bbox['left_eye']
+ components_bbox['left_eye'] = components_bbox['right_eye']
+ components_bbox['right_eye'] = tmp
+ # modify the width coordinate
+ components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0]
+ components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0]
+ components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0]
+
+ # get coordinates
+ locations = []
+ for part in ['left_eye', 'right_eye', 'mouth']:
+ mean = components_bbox[part][0:2]
+ half_len = components_bbox[part][2]
+ if 'eye' in part:
+ half_len *= self.eye_enlarge_ratio
+ loc = np.hstack((mean - half_len + 1, mean + half_len))
+ loc = torch.from_numpy(loc).float()
+ locations.append(loc)
+ return locations
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ # load gt image
+ # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
+ gt_path = self.paths[index]
+ img_bytes = self.file_client.get(gt_path)
+ img_gt = imfrombytes(img_bytes, float32=True)
+
+ # random horizontal flip
+ img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)
+ h, w, _ = img_gt.shape
+
+ # get facial component coordinates
+ if self.crop_components:
+ locations = self.get_component_coordinates(index, status)
+ loc_left_eye, loc_right_eye, loc_mouth = locations
+
+ # ------------------------ generate lq image ------------------------ #
+ # blur
+ kernel = degradations.random_mixed_kernels(
+ self.kernel_list,
+ self.kernel_prob,
+ self.blur_kernel_size,
+ self.blur_sigma,
+ self.blur_sigma, [-math.pi, math.pi],
+ noise_range=None)
+ img_lq = cv2.filter2D(img_gt, -1, kernel)
+ # downsample
+ scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
+ img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
+ # noise
+ if self.noise_range is not None:
+ img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range)
+ # jpeg compression
+ if self.jpeg_range is not None:
+ img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range)
+
+ # resize to original size
+ img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)
+
+ # random color jitter (only for lq)
+ if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
+ img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
+ # random to gray (only for lq)
+ if self.gray_prob and np.random.uniform() < self.gray_prob:
+ img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
+ img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
+ if self.opt.get('gt_gray'): # whether convert GT to gray images
+ img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
+ img_gt = np.tile(img_gt[:, :, None], [1, 1, 3]) # repeat the color channels
+
+ # BGR to RGB, HWC to CHW, numpy to tensor
+ img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
+
+ # random color jitter (pytorch version) (only for lq)
+ if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
+ brightness = self.opt.get('brightness', (0.5, 1.5))
+ contrast = self.opt.get('contrast', (0.5, 1.5))
+ saturation = self.opt.get('saturation', (0, 1.5))
+ hue = self.opt.get('hue', (-0.1, 0.1))
+ img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)
+
+ # round and clip
+ img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.
+
+ # normalize
+ normalize(img_gt, self.mean, self.std, inplace=True)
+ normalize(img_lq, self.mean, self.std, inplace=True)
+
+ if self.crop_components:
+ return_dict = {
+ 'lq': img_lq,
+ 'gt': img_gt,
+ 'gt_path': gt_path,
+ 'loc_left_eye': loc_left_eye,
+ 'loc_right_eye': loc_right_eye,
+ 'loc_mouth': loc_mouth
+ }
+ return return_dict
+ else:
+ return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path}
+
+ def __len__(self):
+ return len(self.paths)
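+
+
+# A minimal (hypothetical) option sketch covering the required degradation keys;
+# the values are placeholders, not the settings used for training.
+#
+#   opt = dict(
+#       dataroot_gt='datasets/ffhq_512', io_backend=dict(type='disk'),
+#       mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], out_size=512, use_hflip=True,
+#       blur_kernel_size=41, kernel_list=['iso', 'aniso'], kernel_prob=[0.5, 0.5],
+#       blur_sigma=[0.1, 10], downsample_range=[0.8, 8], noise_range=[0, 20],
+#       jpeg_range=[60, 100])
+#   dataset = FFHQDegradationDataset(opt)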
diff --git a/StableSR/basicsr/data/paired_image_dataset.py b/StableSR/basicsr/data/paired_image_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..41965cd159ec539aca3d60f5a5ccd84736e13d61
--- /dev/null
+++ b/StableSR/basicsr/data/paired_image_dataset.py
@@ -0,0 +1,115 @@
+from torch.utils import data as data
+from torchvision.transforms.functional import normalize
+
+from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file, paired_paths_from_meta_info_file_2
+from basicsr.data.transforms import augment, paired_random_crop
+from basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor
+from basicsr.utils.registry import DATASET_REGISTRY
+import cv2
+
+
+@DATASET_REGISTRY.register()
+class PairedImageDataset(data.Dataset):
+ """Paired image dataset for image restoration.
+
+ Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
+
+ There are three modes:
+
+ 1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb.
+ 2. **meta_info_file**: Use meta information file to generate paths. \
+ If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
+ 3. **folder**: Scan folders to generate paths. The rest.
+
+ Args:
+ opt (dict): Config for train datasets. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ dataroot_lq (str): Data root path for lq.
+ meta_info_file (str): Path for meta information file.
+ io_backend (dict): IO backend type and other kwarg.
+ filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
+ Default: '{}'.
+ gt_size (int): Cropped patch size for gt patches.
+ use_hflip (bool): Use horizontal flips.
+ use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
+ scale (int): Scale factor, which will be added automatically.
+ phase (str): 'train' or 'val'.
+ """
+
+ def __init__(self, opt):
+ super(PairedImageDataset, self).__init__()
+ self.opt = opt
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ self.mean = opt['mean'] if 'mean' in opt else None
+ self.std = opt['std'] if 'std' in opt else None
+
+ self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
+ if 'filename_tmpl' in opt:
+ self.filename_tmpl = opt['filename_tmpl']
+ else:
+ self.filename_tmpl = '{}'
+
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
+ self.io_backend_opt['client_keys'] = ['lq', 'gt']
+ self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
+ elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
+ self.paths = paired_paths_from_meta_info_file_2([self.lq_folder, self.gt_folder], ['lq', 'gt'],
+ self.opt['meta_info_file'], self.filename_tmpl)
+ else:
+ self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ scale = self.opt['scale']
+
+ # Load gt and lq images. Dimension order: HWC; channel order: BGR;
+ # image range: [0, 1], float32.
+ gt_path = self.paths[index]['gt_path']
+ img_bytes = self.file_client.get(gt_path, 'gt')
+ img_gt = imfrombytes(img_bytes, float32=True)
+ lq_path = self.paths[index]['lq_path']
+ img_bytes = self.file_client.get(lq_path, 'lq')
+ img_lq = imfrombytes(img_bytes, float32=True)
+
+ h, w = img_gt.shape[0:2]
+ # pad
+ if h < self.opt['gt_size'] or w < self.opt['gt_size']:
+ pad_h = max(0, self.opt['gt_size'] - h)
+ pad_w = max(0, self.opt['gt_size'] - w)
+ img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
+ img_lq = cv2.copyMakeBorder(img_lq, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
+
+ # augmentation for training
+ if self.opt['phase'] == 'train':
+ gt_size = self.opt['gt_size']
+ # random crop
+ img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
+ # flip, rotation
+ img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])
+
+ # color space transform
+ if 'color' in self.opt and self.opt['color'] == 'y':
+ img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None]
+ img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None]
+
+ # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets
+ # TODO: It is better to update the datasets, rather than force to crop
+ if self.opt['phase'] != 'train':
+ img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :]
+
+ # BGR to RGB, HWC to CHW, numpy to tensor
+ img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
+ # normalize
+ if self.mean is not None or self.std is not None:
+ normalize(img_lq, self.mean, self.std, inplace=True)
+ normalize(img_gt, self.mean, self.std, inplace=True)
+
+ return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
+
+ def __len__(self):
+ return len(self.paths)
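+
+
+# A rough usage sketch (placeholder paths and values, not a tested config): the
+# 'folder' mode only needs the two data roots plus phase, scale and gt_size.
+#
+#   opt = dict(
+#       dataroot_gt='datasets/DIV2K/HR', dataroot_lq='datasets/DIV2K/LR_x4',
+#       io_backend=dict(type='disk'), gt_size=256, phase='train', scale=4,
+#       use_hflip=True, use_rot=True)
+#   dataset = PairedImageDataset(opt)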
diff --git a/StableSR/basicsr/data/prefetch_dataloader.py b/StableSR/basicsr/data/prefetch_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..332abd32fcb004e6892d12dc69848a4454e3c503
--- /dev/null
+++ b/StableSR/basicsr/data/prefetch_dataloader.py
@@ -0,0 +1,122 @@
+import queue as Queue
+import threading
+import torch
+from torch.utils.data import DataLoader
+
+
+class PrefetchGenerator(threading.Thread):
+ """A general prefetch generator.
+
+ Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
+
+ Args:
+ generator: Python generator.
+ num_prefetch_queue (int): Number of prefetch queue.
+ """
+
+ def __init__(self, generator, num_prefetch_queue):
+ threading.Thread.__init__(self)
+ self.queue = Queue.Queue(num_prefetch_queue)
+ self.generator = generator
+ self.daemon = True
+ self.start()
+
+ def run(self):
+ for item in self.generator:
+ self.queue.put(item)
+ self.queue.put(None)
+
+ def __next__(self):
+ next_item = self.queue.get()
+ if next_item is None:
+ raise StopIteration
+ return next_item
+
+ def __iter__(self):
+ return self
+
+
+class PrefetchDataLoader(DataLoader):
+ """Prefetch version of dataloader.
+
+ Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
+
+ TODO:
+ Need to test on single gpu and ddp (multi-gpu). There is a known issue in
+ ddp.
+
+ Args:
+ num_prefetch_queue (int): Number of prefetch queue.
+ kwargs (dict): Other arguments for dataloader.
+ """
+
+ def __init__(self, num_prefetch_queue, **kwargs):
+ self.num_prefetch_queue = num_prefetch_queue
+ super(PrefetchDataLoader, self).__init__(**kwargs)
+
+ def __iter__(self):
+ return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
+
+
+class CPUPrefetcher():
+ """CPU prefetcher.
+
+ Args:
+ loader: Dataloader.
+ """
+
+ def __init__(self, loader):
+ self.ori_loader = loader
+ self.loader = iter(loader)
+
+ def next(self):
+ try:
+ return next(self.loader)
+ except StopIteration:
+ return None
+
+ def reset(self):
+ self.loader = iter(self.ori_loader)
+
+
+class CUDAPrefetcher():
+ """CUDA prefetcher.
+
+ Reference: https://github.com/NVIDIA/apex/issues/304#
+
+ It may consume more GPU memory.
+
+ Args:
+ loader: Dataloader.
+ opt (dict): Options.
+ """
+
+ def __init__(self, loader, opt):
+ self.ori_loader = loader
+ self.loader = iter(loader)
+ self.opt = opt
+ self.stream = torch.cuda.Stream()
+ self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
+ self.preload()
+
+ def preload(self):
+ try:
+ self.batch = next(self.loader) # self.batch is a dict
+ except StopIteration:
+ self.batch = None
+ return None
+ # put tensors to gpu
+ with torch.cuda.stream(self.stream):
+ for k, v in self.batch.items():
+ if torch.is_tensor(v):
+ self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
+
+ def next(self):
+ torch.cuda.current_stream().wait_stream(self.stream)
+ batch = self.batch
+ self.preload()
+ return batch
+
+ def reset(self):
+ self.loader = iter(self.ori_loader)
+ self.preload()
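+
+
+# A minimal training-loop sketch for CUDAPrefetcher (illustrative only;
+# train_loader and opt are assumed to be defined elsewhere, with opt['num_gpu']):
+#
+#   prefetcher = CUDAPrefetcher(train_loader, opt)
+#   batch = prefetcher.next()
+#   while batch is not None:
+#       ...  # forward/backward on batch['lq'], batch['gt']
+#       batch = prefetcher.next()
+#   prefetcher.reset()  # restart for the next epoch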
diff --git a/StableSR/basicsr/data/realesrgan_dataset.py b/StableSR/basicsr/data/realesrgan_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b7c0603d8353f5457b0dd96f9a9a876a192d113
--- /dev/null
+++ b/StableSR/basicsr/data/realesrgan_dataset.py
@@ -0,0 +1,242 @@
+import cv2
+import math
+import numpy as np
+import os
+import os.path as osp
+import random
+import time
+import torch
+from pathlib import Path
+from torch.utils import data as data
+
+from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
+from basicsr.data.transforms import augment
+from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from basicsr.utils.registry import DATASET_REGISTRY
+
+@DATASET_REGISTRY.register(suffix='basicsr')
+class RealESRGANDataset(data.Dataset):
+ """Modified dataset based on the dataset used for Real-ESRGAN model:
+ Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+ It loads gt (Ground-Truth) images, and augments them.
+ It also generates blur kernels and sinc kernels for generating low-quality images.
+ Note that the low-quality images are processed as tensors on GPUs for faster processing.
+
+ Args:
+ opt (dict): Config for train datasets. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ meta_info (str): Path for meta information file.
+ io_backend (dict): IO backend type and other kwarg.
+ use_hflip (bool): Use horizontal flips.
+ use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
+ Please see more options in the code.
+ """
+
+ def __init__(self, opt):
+ super(RealESRGANDataset, self).__init__()
+ self.opt = opt
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ if 'crop_size' in opt:
+ self.crop_size = opt['crop_size']
+ else:
+ self.crop_size = 512
+ if 'image_type' not in opt:
+ opt['image_type'] = 'png'
+
+ # support multiple types of data: file paths and meta info files; lmdb is not supported
+ self.paths = []
+ if 'meta_info' in opt:
+ with open(self.opt['meta_info']) as fin:
+ paths = [line.strip().split(' ')[0] for line in fin]
+ self.paths = [v for v in paths]
+ if 'meta_num' in opt:
+ self.paths = sorted(self.paths)[:opt['meta_num']]
+ if 'gt_path' in opt:
+ if isinstance(opt['gt_path'], str):
+ self.paths.extend(sorted([str(x) for x in Path(opt['gt_path']).glob('*.'+opt['image_type'])]))
+ else:
+ self.paths.extend(sorted([str(x) for x in Path(opt['gt_path'][0]).glob('*.'+opt['image_type'])]))
+ if len(opt['gt_path']) > 1:
+ for i in range(len(opt['gt_path'])-1):
+ self.paths.extend(sorted([str(x) for x in Path(opt['gt_path'][i+1]).glob('*.'+opt['image_type'])]))
+ if 'imagenet_path' in opt:
+ class_list = os.listdir(opt['imagenet_path'])
+ for class_file in class_list:
+ self.paths.extend(sorted([str(x) for x in Path(os.path.join(opt['imagenet_path'], class_file)).glob('*.'+'JPEG')]))
+ if 'face_gt_path' in opt:
+ if isinstance(opt['face_gt_path'], str):
+ face_list = sorted([str(x) for x in Path(opt['face_gt_path']).glob('*.'+opt['image_type'])])
+ self.paths.extend(face_list[:opt['num_face']])
+ else:
+ face_list = sorted([str(x) for x in Path(opt['face_gt_path'][0]).glob('*.'+opt['image_type'])])
+ self.paths.extend(face_list[:opt['num_face']])
+ if len(opt['face_gt_path']) > 1:
+ for i in range(len(opt['face_gt_path'])-1):
+ self.paths.extend(sorted([str(x) for x in Path(opt['face_gt_path'][i+1]).glob('*.'+opt['image_type'])])[:opt['num_face']])
+
+ # limit number of pictures for test
+ if 'num_pic' in opt:
+ if 'val' in opt or 'test' in opt:
+ random.shuffle(self.paths)
+ self.paths = self.paths[:opt['num_pic']]
+ else:
+ self.paths = self.paths[:opt['num_pic']]
+
+ if 'mul_num' in opt:
+ self.paths = self.paths * opt['mul_num']
+ # print('>>>>>>>>>>>>>>>>>>>>>')
+ # print(self.paths)
+
+ # blur settings for the first degradation
+ self.blur_kernel_size = opt['blur_kernel_size']
+ self.kernel_list = opt['kernel_list']
+ self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability
+ self.blur_sigma = opt['blur_sigma']
+ self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels
+ self.betap_range = opt['betap_range'] # betap used in plateau blur kernels
+ self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters
+
+ # blur settings for the second degradation
+ self.blur_kernel_size2 = opt['blur_kernel_size2']
+ self.kernel_list2 = opt['kernel_list2']
+ self.kernel_prob2 = opt['kernel_prob2']
+ self.blur_sigma2 = opt['blur_sigma2']
+ self.betag_range2 = opt['betag_range2']
+ self.betap_range2 = opt['betap_range2']
+ self.sinc_prob2 = opt['sinc_prob2']
+
+ # a final sinc filter
+ self.final_sinc_prob = opt['final_sinc_prob']
+
+ self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
+ # TODO: kernel range is now hard-coded, should be in the config file
+ self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
+ self.pulse_tensor[10, 10] = 1
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ # -------------------------------- Load gt images -------------------------------- #
+ # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
+ gt_path = self.paths[index]
+ # avoid errors caused by high latency in reading files
+ retry = 3
+ while retry > 0:
+ try:
+ img_bytes = self.file_client.get(gt_path, 'gt')
+ except (IOError, OSError) as e:
+ # logger = get_root_logger()
+ # logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}')
+ # change another file to read
+ index = random.randint(0, self.__len__()-1)
+ gt_path = self.paths[index]
+ time.sleep(1) # sleep 1s for occasional server congestion
+ else:
+ break
+ finally:
+ retry -= 1
+ img_gt = imfrombytes(img_bytes, float32=True)
+ # filter the dataset and remove images with too low quality
+ img_size = os.path.getsize(gt_path)
+ img_size = img_size/1024
+
+ while img_gt.shape[0] * img_gt.shape[1] < 384*384 or img_size<100:
+ index = random.randint(0, self.__len__()-1)
+ gt_path = self.paths[index]
+
+ time.sleep(0.1) # short sleep for occasional server congestion
+ img_bytes = self.file_client.get(gt_path, 'gt')
+ img_gt = imfrombytes(img_bytes, float32=True)
+ img_size = os.path.getsize(gt_path)
+ img_size = img_size/1024
+
+ # -------------------- Do augmentation for training: flip, rotation -------------------- #
+ img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])
+
+ # crop or pad to self.crop_size (default: 512)
+ h, w = img_gt.shape[0:2]
+ crop_pad_size = self.crop_size
+ # pad
+ if h < crop_pad_size or w < crop_pad_size:
+ pad_h = max(0, crop_pad_size - h)
+ pad_w = max(0, crop_pad_size - w)
+ img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
+ # crop
+ if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
+ h, w = img_gt.shape[0:2]
+ # randomly choose top and left coordinates
+ top = random.randint(0, h - crop_pad_size)
+ left = random.randint(0, w - crop_pad_size)
+ # top = (h - crop_pad_size) // 2 -1
+ # left = (w - crop_pad_size) // 2 -1
+ img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]
+
+ # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
+ kernel_size = random.choice(self.kernel_range)
+ if np.random.uniform() < self.opt['sinc_prob']:
+ # this sinc filter setting is for kernels ranging from [7, 21]
+ if kernel_size < 13:
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
+ else:
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
+ kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
+ else:
+ kernel = random_mixed_kernels(
+ self.kernel_list,
+ self.kernel_prob,
+ kernel_size,
+ self.blur_sigma,
+ self.blur_sigma, [-math.pi, math.pi],
+ self.betag_range,
+ self.betap_range,
+ noise_range=None)
+ # pad kernel
+ pad_size = (21 - kernel_size) // 2
+ kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
+
+ # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
+ kernel_size = random.choice(self.kernel_range)
+ if np.random.uniform() < self.opt['sinc_prob2']:
+ if kernel_size < 13:
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
+ else:
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
+ kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
+ else:
+ kernel2 = random_mixed_kernels(
+ self.kernel_list2,
+ self.kernel_prob2,
+ kernel_size,
+ self.blur_sigma2,
+ self.blur_sigma2, [-math.pi, math.pi],
+ self.betag_range2,
+ self.betap_range2,
+ noise_range=None)
+
+ # pad kernel
+ pad_size = (21 - kernel_size) // 2
+ kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
+
+ # ------------------------------------- the final sinc kernel ------------------------------------- #
+ if np.random.uniform() < self.opt['final_sinc_prob']:
+ kernel_size = random.choice(self.kernel_range)
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
+ sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
+ sinc_kernel = torch.FloatTensor(sinc_kernel)
+ else:
+ sinc_kernel = self.pulse_tensor
+
+ # BGR to RGB, HWC to CHW, numpy to tensor
+ img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
+ kernel = torch.FloatTensor(kernel)
+ kernel2 = torch.FloatTensor(kernel2)
+
+ return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
+ return return_d
+
+ def __len__(self):
+ return len(self.paths)
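+
+
+# Sketch of a returned sample (illustrative): each item carries the clean crop
+# plus three kernels; the degradations themselves are applied later (e.g. on the
+# GPU by the training code), not inside this dataset.
+#
+#   sample = dataset[0]
+#   sample['gt']           # (3, crop_size, crop_size) tensor in [0, 1]
+#   sample['kernel1']      # 21x21 first-stage blur kernel
+#   sample['kernel2']      # 21x21 second-stage blur kernel
+#   sample['sinc_kernel']  # 21x21 final sinc kernel (or the identity pulse tensor)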
diff --git a/StableSR/basicsr/data/realesrgan_paired_dataset.py b/StableSR/basicsr/data/realesrgan_paired_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d0c6159d448f26fc8a256d6a9d0c51096b78fe0
--- /dev/null
+++ b/StableSR/basicsr/data/realesrgan_paired_dataset.py
@@ -0,0 +1,114 @@
+import os
+from torch.utils import data as data
+from torchvision.transforms.functional import normalize
+
+from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
+from basicsr.data.transforms import augment, paired_random_crop
+from basicsr.utils import FileClient, imfrombytes, img2tensor
+from basicsr.utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register(suffix='basicsr')
+class RealESRGANPairedDataset(data.Dataset):
+ """Paired image dataset for image restoration.
+
+ Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
+
+ There are three modes:
+
+ 1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb.
+ 2. **meta_info_file**: Use meta information file to generate paths. \
+ If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
+ 3. **folder**: Scan folders to generate paths. The rest.
+
+ Args:
+ opt (dict): Config for train datasets. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ dataroot_lq (str): Data root path for lq.
+ meta_info (str): Path for meta information file.
+ io_backend (dict): IO backend type and other kwarg.
+ filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
+ Default: '{}'.
+ gt_size (int): Cropped patch size for gt patches.
+ use_hflip (bool): Use horizontal flips.
+ use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
+ scale (int): Scale factor, which will be added automatically.
+ phase (str): 'train' or 'val'.
+ """
+
+ def __init__(self, opt):
+ super(RealESRGANPairedDataset, self).__init__()
+ self.opt = opt
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ # mean and std for normalizing the input images
+ self.mean = opt['mean'] if 'mean' in opt else None
+ self.std = opt['std'] if 'std' in opt else None
+
+ self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
+ self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'
+
+ # file client (lmdb io backend)
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
+ self.io_backend_opt['client_keys'] = ['lq', 'gt']
+ self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
+ elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
+ # disk backend with meta_info
+ # Each line in the meta_info describes the relative path to an image
+ with open(self.opt['meta_info']) as fin:
+ paths = [line.strip() for line in fin]
+ self.paths = []
+ for path in paths:
+ gt_path, lq_path = path.split(', ')
+ gt_path = os.path.join(self.gt_folder, gt_path)
+ lq_path = os.path.join(self.lq_folder, lq_path)
+ self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
+ else:
+ # disk backend
+ # it will scan the whole folder to get meta info
+ # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file
+ self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
+
+ if 'num_pic' in self.opt:
+ self.paths = self.paths[:self.opt['num_pic']]
+ if 'phase' not in self.opt:
+ self.opt['phase'] = 'test'
+ if 'scale' not in self.opt:
+ self.opt['scale'] = 1
+
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ scale = self.opt['scale']
+
+ # Load gt and lq images. Dimension order: HWC; channel order: BGR;
+ # image range: [0, 1], float32.
+ gt_path = self.paths[index]['gt_path']
+ img_bytes = self.file_client.get(gt_path, 'gt')
+ img_gt = imfrombytes(img_bytes, float32=True)
+ lq_path = self.paths[index]['lq_path']
+ img_bytes = self.file_client.get(lq_path, 'lq')
+ img_lq = imfrombytes(img_bytes, float32=True)
+
+ # augmentation for training
+ if self.opt['phase'] == 'train':
+ gt_size = self.opt['gt_size']
+ # random crop
+ img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
+ # flip, rotation
+ img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])
+
+ # BGR to RGB, HWC to CHW, numpy to tensor
+ img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
+ # normalize
+ if self.mean is not None or self.std is not None:
+ normalize(img_lq, self.mean, self.std, inplace=True)
+ normalize(img_gt, self.mean, self.std, inplace=True)
+
+ return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
+
+ def __len__(self):
+ return len(self.paths)
diff --git a/StableSR/basicsr/data/reds_dataset.py b/StableSR/basicsr/data/reds_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..fabef1d7e80866888f3b57ecfeb4d97c93bcb5cd
--- /dev/null
+++ b/StableSR/basicsr/data/reds_dataset.py
@@ -0,0 +1,352 @@
+import numpy as np
+import random
+import torch
+from pathlib import Path
+from torch.utils import data as data
+
+from basicsr.data.transforms import augment, paired_random_crop
+from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from basicsr.utils.flow_util import dequantize_flow
+from basicsr.utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class REDSDataset(data.Dataset):
+ """REDS dataset for training.
+
+ The keys are generated from a meta info txt file.
+ basicsr/data/meta_info/meta_info_REDS_GT.txt
+
+ Each line contains:
+ 1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
+ a white space.
+ Examples:
+ 000 100 (720,1280,3)
+ 001 100 (720,1280,3)
+ ...
+
+ Key examples: "000/00000000"
+ GT (gt): Ground-Truth;
+ LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
+
+ Args:
+ opt (dict): Config for train dataset. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ dataroot_lq (str): Data root path for lq.
+ dataroot_flow (str, optional): Data root path for flow.
+ meta_info_file (str): Path for meta information file.
+ val_partition (str): Validation partition types. 'REDS4' or 'official'.
+ io_backend (dict): IO backend type and other kwarg.
+ num_frame (int): Window size for input frames.
+ gt_size (int): Cropped patch size for gt patches.
+ interval_list (list): Interval list for temporal augmentation.
+ random_reverse (bool): Random reverse input frames.
+ use_hflip (bool): Use horizontal flips.
+ use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
+ scale (int): Scale factor, which will be added automatically.
+ """
+
+ def __init__(self, opt):
+ super(REDSDataset, self).__init__()
+ self.opt = opt
+ self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
+ self.flow_root = Path(opt['dataroot_flow']) if opt['dataroot_flow'] is not None else None
+ assert opt['num_frame'] % 2 == 1, (f'num_frame should be odd number, but got {opt["num_frame"]}')
+ self.num_frame = opt['num_frame']
+ self.num_half_frames = opt['num_frame'] // 2
+
+ self.keys = []
+ with open(opt['meta_info_file'], 'r') as fin:
+ for line in fin:
+ folder, frame_num, _ = line.split(' ')
+ self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])
+
+ # remove the video clips used in validation
+ if opt['val_partition'] == 'REDS4':
+ val_partition = ['000', '011', '015', '020']
+ elif opt['val_partition'] == 'official':
+ val_partition = [f'{v:03d}' for v in range(240, 270)]
+ else:
+ raise ValueError(f'Wrong validation partition {opt["val_partition"]}.'
+ f"Supported ones are ['official', 'REDS4'].")
+ self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]
+
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ self.is_lmdb = False
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.is_lmdb = True
+ if self.flow_root is not None:
+ self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
+ self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
+ else:
+ self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
+ self.io_backend_opt['client_keys'] = ['lq', 'gt']
+
+ # temporal augmentation configs
+ self.interval_list = opt['interval_list']
+ self.random_reverse = opt['random_reverse']
+ interval_str = ','.join(str(x) for x in opt['interval_list'])
+ logger = get_root_logger()
+ logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
+ f'random reverse is {self.random_reverse}.')
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ scale = self.opt['scale']
+ gt_size = self.opt['gt_size']
+ key = self.keys[index]
+ clip_name, frame_name = key.split('/') # key example: 000/00000000
+ center_frame_idx = int(frame_name)
+
+ # determine the neighboring frames
+ interval = random.choice(self.interval_list)
+
+ # ensure not exceeding the borders
+ start_frame_idx = center_frame_idx - self.num_half_frames * interval
+ end_frame_idx = center_frame_idx + self.num_half_frames * interval
+ # each clip has 100 frames starting from 0 to 99
+ while (start_frame_idx < 0) or (end_frame_idx > 99):
+ center_frame_idx = random.randint(0, 99)
+ start_frame_idx = (center_frame_idx - self.num_half_frames * interval)
+ end_frame_idx = center_frame_idx + self.num_half_frames * interval
+ frame_name = f'{center_frame_idx:08d}'
+ neighbor_list = list(range(start_frame_idx, end_frame_idx + 1, interval))
+ # random reverse
+ if self.random_reverse and random.random() < 0.5:
+ neighbor_list.reverse()
+
+ assert len(neighbor_list) == self.num_frame, (f'Wrong length of neighbor list: {len(neighbor_list)}')
+
+ # get the GT frame (as the center frame)
+ if self.is_lmdb:
+ img_gt_path = f'{clip_name}/{frame_name}'
+ else:
+ img_gt_path = self.gt_root / clip_name / f'{frame_name}.png'
+ img_bytes = self.file_client.get(img_gt_path, 'gt')
+ img_gt = imfrombytes(img_bytes, float32=True)
+
+ # get the neighboring LQ frames
+ img_lqs = []
+ for neighbor in neighbor_list:
+ if self.is_lmdb:
+ img_lq_path = f'{clip_name}/{neighbor:08d}'
+ else:
+ img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
+ img_bytes = self.file_client.get(img_lq_path, 'lq')
+ img_lq = imfrombytes(img_bytes, float32=True)
+ img_lqs.append(img_lq)
+
+ # get flows
+ if self.flow_root is not None:
+ img_flows = []
+ # read previous flows
+ for i in range(self.num_half_frames, 0, -1):
+ if self.is_lmdb:
+ flow_path = f'{clip_name}/{frame_name}_p{i}'
+ else:
+ flow_path = (self.flow_root / clip_name / f'{frame_name}_p{i}.png')
+ img_bytes = self.file_client.get(flow_path, 'flow')
+ cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False) # uint8, [0, 255]
+ dx, dy = np.split(cat_flow, 2, axis=0)
+ flow = dequantize_flow(dx, dy, max_val=20, denorm=False) # we use max_val 20 here.
+ img_flows.append(flow)
+ # read next flows
+ for i in range(1, self.num_half_frames + 1):
+ if self.is_lmdb:
+ flow_path = f'{clip_name}/{frame_name}_n{i}'
+ else:
+ flow_path = (self.flow_root / clip_name / f'{frame_name}_n{i}.png')
+ img_bytes = self.file_client.get(flow_path, 'flow')
+ cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False) # uint8, [0, 255]
+ dx, dy = np.split(cat_flow, 2, axis=0)
+ flow = dequantize_flow(dx, dy, max_val=20, denorm=False) # we use max_val 20 here.
+ img_flows.append(flow)
+
+ # for random crop, here, img_flows and img_lqs have the same
+ # spatial size
+ img_lqs.extend(img_flows)
+
+ # randomly crop
+ img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)
+ if self.flow_root is not None:
+ img_lqs, img_flows = img_lqs[:self.num_frame], img_lqs[self.num_frame:]
+
+ # augmentation - flip, rotate
+ img_lqs.append(img_gt)
+ if self.flow_root is not None:
+ img_results, img_flows = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'], img_flows)
+ else:
+ img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
+
+ img_results = img2tensor(img_results)
+ img_lqs = torch.stack(img_results[0:-1], dim=0)
+ img_gt = img_results[-1]
+
+ if self.flow_root is not None:
+ img_flows = img2tensor(img_flows)
+ # add the zero center flow
+ img_flows.insert(self.num_half_frames, torch.zeros_like(img_flows[0]))
+ img_flows = torch.stack(img_flows, dim=0)
+
+ # img_lqs: (t, c, h, w)
+ # img_flows: (t, 2, h, w)
+ # img_gt: (c, h, w)
+ # key: str
+ if self.flow_root is not None:
+ return {'lq': img_lqs, 'flow': img_flows, 'gt': img_gt, 'key': key}
+ else:
+ return {'lq': img_lqs, 'gt': img_gt, 'key': key}
+
+ def __len__(self):
+ return len(self.keys)
+
+
+@DATASET_REGISTRY.register()
+class REDSRecurrentDataset(data.Dataset):
+ """REDS dataset for training recurrent networks.
+
+ The keys are generated from a meta info txt file.
+ basicsr/data/meta_info/meta_info_REDS_GT.txt
+
+ Each line contains:
+ 1. subfolder (clip) name; 2. frame number; 3. image shape, separated by
+ a white space.
+ Examples:
+ 000 100 (720,1280,3)
+ 001 100 (720,1280,3)
+ ...
+
+ Key examples: "000/00000000"
+ GT (gt): Ground-Truth;
+ LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
+
+ Args:
+ opt (dict): Config for train dataset. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ dataroot_lq (str): Data root path for lq.
+ dataroot_flow (str, optional): Data root path for flow.
+ meta_info_file (str): Path for meta information file.
+ val_partition (str): Validation partition types. 'REDS4' or 'official'.
+ io_backend (dict): IO backend type and other kwarg.
+ num_frame (int): Window size for input frames.
+            gt_size (int): Cropped patch size for gt patches.
+ interval_list (list): Interval list for temporal augmentation.
+ random_reverse (bool): Random reverse input frames.
+ use_hflip (bool): Use horizontal flips.
+ use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
+            scale (int): Scale factor, which will be added automatically.
+ """
+
+ def __init__(self, opt):
+ super(REDSRecurrentDataset, self).__init__()
+ self.opt = opt
+ self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
+ self.num_frame = opt['num_frame']
+
+ self.keys = []
+ with open(opt['meta_info_file'], 'r') as fin:
+ for line in fin:
+ folder, frame_num, _ = line.split(' ')
+ self.keys.extend([f'{folder}/{i:08d}' for i in range(int(frame_num))])
+
+ # remove the video clips used in validation
+ if opt['val_partition'] == 'REDS4':
+ val_partition = ['000', '011', '015', '020']
+ elif opt['val_partition'] == 'official':
+ val_partition = [f'{v:03d}' for v in range(240, 270)]
+ else:
+ raise ValueError(f'Wrong validation partition {opt["val_partition"]}.'
+ f"Supported ones are ['official', 'REDS4'].")
+ if opt['test_mode']:
+ self.keys = [v for v in self.keys if v.split('/')[0] in val_partition]
+ else:
+ self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]
+
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ self.is_lmdb = False
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.is_lmdb = True
+ if hasattr(self, 'flow_root') and self.flow_root is not None:
+ self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root, self.flow_root]
+ self.io_backend_opt['client_keys'] = ['lq', 'gt', 'flow']
+ else:
+ self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
+ self.io_backend_opt['client_keys'] = ['lq', 'gt']
+
+ # temporal augmentation configs
+ self.interval_list = opt.get('interval_list', [1])
+ self.random_reverse = opt.get('random_reverse', False)
+ interval_str = ','.join(str(x) for x in self.interval_list)
+ logger = get_root_logger()
+ logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
+ f'random reverse is {self.random_reverse}.')
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ scale = self.opt['scale']
+ gt_size = self.opt['gt_size']
+ key = self.keys[index]
+ clip_name, frame_name = key.split('/') # key example: 000/00000000
+
+ # determine the neighboring frames
+ interval = random.choice(self.interval_list)
+
+ # ensure not exceeding the borders
+ start_frame_idx = int(frame_name)
+ if start_frame_idx > 100 - self.num_frame * interval:
+ start_frame_idx = random.randint(0, 100 - self.num_frame * interval)
+ end_frame_idx = start_frame_idx + self.num_frame * interval
+
+ neighbor_list = list(range(start_frame_idx, end_frame_idx, interval))
+
+ # random reverse
+ if self.random_reverse and random.random() < 0.5:
+ neighbor_list.reverse()
+
+ # get the neighboring LQ and GT frames
+ img_lqs = []
+ img_gts = []
+ for neighbor in neighbor_list:
+ if self.is_lmdb:
+ img_lq_path = f'{clip_name}/{neighbor:08d}'
+ img_gt_path = f'{clip_name}/{neighbor:08d}'
+ else:
+ img_lq_path = self.lq_root / clip_name / f'{neighbor:08d}.png'
+ img_gt_path = self.gt_root / clip_name / f'{neighbor:08d}.png'
+
+ # get LQ
+ img_bytes = self.file_client.get(img_lq_path, 'lq')
+ img_lq = imfrombytes(img_bytes, float32=True)
+ img_lqs.append(img_lq)
+
+ # get GT
+ img_bytes = self.file_client.get(img_gt_path, 'gt')
+ img_gt = imfrombytes(img_bytes, float32=True)
+ img_gts.append(img_gt)
+
+ # randomly crop
+ img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)
+
+ # augmentation - flip, rotate
+ img_lqs.extend(img_gts)
+ img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
+
+ img_results = img2tensor(img_results)
+ img_gts = torch.stack(img_results[len(img_lqs) // 2:], dim=0)
+ img_lqs = torch.stack(img_results[:len(img_lqs) // 2], dim=0)
+
+ # img_lqs: (t, c, h, w)
+ # img_gts: (t, c, h, w)
+ # key: str
+ return {'lq': img_lqs, 'gt': img_gts, 'key': key}
+
+ def __len__(self):
+ return len(self.keys)
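
As a quick illustration of the options this dataset expects, here is a minimal construction sketch. It is not part of the patch: the dataroot paths, frame window and patch size are placeholders, and the import path assumes the class lives in `basicsr/data/reds_dataset.py` as in upstream BasicSR.

```python
# Hypothetical usage sketch for REDSRecurrentDataset; all paths below are placeholders.
from basicsr.data.reds_dataset import REDSRecurrentDataset

opt = {
    'dataroot_gt': 'datasets/REDS/train_sharp',              # placeholder path
    'dataroot_lq': 'datasets/REDS/train_sharp_bicubic/X4',   # placeholder path
    'meta_info_file': 'basicsr/data/meta_info/meta_info_REDS_GT.txt',
    'val_partition': 'REDS4',
    'test_mode': False,
    'io_backend': {'type': 'disk'},
    'num_frame': 15,
    'gt_size': 256,
    'interval_list': [1],
    'random_reverse': False,
    'use_hflip': True,
    'use_rot': True,
    'scale': 4,
}
dataset = REDSRecurrentDataset(opt)
sample = dataset[0]
print(sample['lq'].shape, sample['gt'].shape)  # both (t, c, h, w)
```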
diff --git a/StableSR/basicsr/data/single_image_dataset.py b/StableSR/basicsr/data/single_image_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8d1a94d1723fb832b0c6fc897e72e0081c4a399
--- /dev/null
+++ b/StableSR/basicsr/data/single_image_dataset.py
@@ -0,0 +1,164 @@
+from os import path as osp
+from torch.utils import data as data
+from torchvision.transforms.functional import normalize
+
+from basicsr.data.data_util import paths_from_lmdb
+from basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr, scandir
+from basicsr.utils.registry import DATASET_REGISTRY
+
+from pathlib import Path
+import random
+import cv2
+import numpy as np
+import torch
+
+@DATASET_REGISTRY.register()
+class SingleImageDataset(data.Dataset):
+ """Read only lq images in the test phase.
+
+ Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc).
+
+ There are two modes:
+ 1. 'meta_info_file': Use meta information file to generate paths.
+ 2. 'folder': Scan folders to generate paths.
+
+ Args:
+ opt (dict): Config for train datasets. It contains the following keys:
+ dataroot_lq (str): Data root path for lq.
+ meta_info_file (str): Path for meta information file.
+ io_backend (dict): IO backend type and other kwarg.
+ """
+
+ def __init__(self, opt):
+ super(SingleImageDataset, self).__init__()
+ self.opt = opt
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ self.mean = opt['mean'] if 'mean' in opt else None
+ self.std = opt['std'] if 'std' in opt else None
+ self.lq_folder = opt['dataroot_lq']
+
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.io_backend_opt['db_paths'] = [self.lq_folder]
+ self.io_backend_opt['client_keys'] = ['lq']
+ self.paths = paths_from_lmdb(self.lq_folder)
+ elif 'meta_info_file' in self.opt:
+ with open(self.opt['meta_info_file'], 'r') as fin:
+ self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin]
+ else:
+ self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ # load lq image
+ lq_path = self.paths[index]
+ img_bytes = self.file_client.get(lq_path, 'lq')
+ img_lq = imfrombytes(img_bytes, float32=True)
+
+ # color space transform
+ if 'color' in self.opt and self.opt['color'] == 'y':
+ img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None]
+
+ # BGR to RGB, HWC to CHW, numpy to tensor
+ img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
+ # normalize
+ if self.mean is not None or self.std is not None:
+ normalize(img_lq, self.mean, self.std, inplace=True)
+ return {'lq': img_lq, 'lq_path': lq_path}
+
+ def __len__(self):
+ return len(self.paths)
+
+@DATASET_REGISTRY.register()
+class SingleImageNPDataset(data.Dataset):
+ """Read only lq images in the test phase.
+
+ Read diffusion generated data for training CFW.
+
+ Args:
+ opt (dict): Config for train datasets. It contains the following keys:
+ gt_path: Data root path for training data. The path needs to contain the following folders:
+ gts: Ground-truth images.
+ inputs: Input LQ images.
+            latents: The corresponding HQ latent code generated by the diffusion model given the input LQ image.
+ samples: The corresponding HQ image given the HQ latent code, just for verification.
+ io_backend (dict): IO backend type and other kwarg.
+ """
+
+ def __init__(self, opt):
+ super(SingleImageNPDataset, self).__init__()
+ self.opt = opt
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ self.mean = opt['mean'] if 'mean' in opt else None
+ self.std = opt['std'] if 'std' in opt else None
+ if 'image_type' not in opt:
+ opt['image_type'] = 'png'
+
+ if isinstance(opt['gt_path'], str):
+ self.gt_paths = sorted([str(x) for x in Path(opt['gt_path']+'/gts').glob('*.'+opt['image_type'])])
+ self.lq_paths = sorted([str(x) for x in Path(opt['gt_path']+'/inputs').glob('*.'+opt['image_type'])])
+ self.np_paths = sorted([str(x) for x in Path(opt['gt_path']+'/latents').glob('*.npy')])
+ self.sample_paths = sorted([str(x) for x in Path(opt['gt_path']+'/samples').glob('*.'+opt['image_type'])])
+ else:
+ self.gt_paths = sorted([str(x) for x in Path(opt['gt_path'][0]+'/gts').glob('*.'+opt['image_type'])])
+ self.lq_paths = sorted([str(x) for x in Path(opt['gt_path'][0]+'/inputs').glob('*.'+opt['image_type'])])
+ self.np_paths = sorted([str(x) for x in Path(opt['gt_path'][0]+'/latents').glob('*.npy')])
+ self.sample_paths = sorted([str(x) for x in Path(opt['gt_path'][0]+'/samples').glob('*.'+opt['image_type'])])
+ if len(opt['gt_path']) > 1:
+ for i in range(len(opt['gt_path'])-1):
+ self.gt_paths.extend(sorted([str(x) for x in Path(opt['gt_path'][i+1]+'/gts').glob('*.'+opt['image_type'])]))
+ self.lq_paths.extend(sorted([str(x) for x in Path(opt['gt_path'][i+1]+'/inputs').glob('*.'+opt['image_type'])]))
+ self.np_paths.extend(sorted([str(x) for x in Path(opt['gt_path'][i+1]+'/latents').glob('*.npy')]))
+ self.sample_paths.extend(sorted([str(x) for x in Path(opt['gt_path'][i+1]+'/samples').glob('*.'+opt['image_type'])]))
+
+ assert len(self.gt_paths) == len(self.lq_paths)
+ assert len(self.gt_paths) == len(self.np_paths)
+ assert len(self.gt_paths) == len(self.sample_paths)
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ # load lq image
+ lq_path = self.lq_paths[index]
+ gt_path = self.gt_paths[index]
+ sample_path = self.sample_paths[index]
+ np_path = self.np_paths[index]
+
+ img_bytes = self.file_client.get(lq_path, 'lq')
+ img_lq = imfrombytes(img_bytes, float32=True)
+
+ img_bytes_gt = self.file_client.get(gt_path, 'gt')
+ img_gt = imfrombytes(img_bytes_gt, float32=True)
+
+ img_bytes_sample = self.file_client.get(sample_path, 'sample')
+ img_sample = imfrombytes(img_bytes_sample, float32=True)
+
+ latent_np = np.load(np_path)
+
+ # color space transform
+ if 'color' in self.opt and self.opt['color'] == 'y':
+ img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None]
+ img_gt = rgb2ycbcr(img_gt, y_only=True)[..., None]
+ img_sample = rgb2ycbcr(img_sample, y_only=True)[..., None]
+
+ # BGR to RGB, HWC to CHW, numpy to tensor
+ img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
+ img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
+ img_sample = img2tensor(img_sample, bgr2rgb=True, float32=True)
+ latent_np = torch.from_numpy(latent_np).float()
+ latent_np = latent_np.to(img_gt.device)
+ # normalize
+ if self.mean is not None or self.std is not None:
+ normalize(img_lq, self.mean, self.std, inplace=True)
+ normalize(img_gt, self.mean, self.std, inplace=True)
+ normalize(img_sample, self.mean, self.std, inplace=True)
+ return {'lq': img_lq, 'lq_path': lq_path, 'gt': img_gt, 'gt_path': gt_path, 'latent': latent_np[0], 'latent_path': np_path, 'sample': img_sample, 'sample_path': sample_path}
+
+ def __len__(self):
+ return len(self.gt_paths)
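
For reference, a minimal sketch of driving `SingleImageDataset` directly; the folder path is a placeholder and `'disk'` is assumed as the IO backend type.

```python
# Hypothetical usage sketch for SingleImageDataset; the LQ folder is a placeholder.
from basicsr.data.single_image_dataset import SingleImageDataset

opt = {
    'dataroot_lq': 'datasets/my_lq_images',   # placeholder folder scanned for LQ images
    'io_backend': {'type': 'disk'},
    # 'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5],  # optional normalization
}
dataset = SingleImageDataset(opt)
item = dataset[0]
print(item['lq'].shape, item['lq_path'])  # CHW float tensor and its source path
```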
diff --git a/StableSR/basicsr/data/transforms.py b/StableSR/basicsr/data/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..c700a399bb737a2286ea705fcebd937e6fb54ca7
--- /dev/null
+++ b/StableSR/basicsr/data/transforms.py
@@ -0,0 +1,240 @@
+import cv2
+import random
+import torch
+
+
+def mod_crop(img, scale):
+ """Mod crop images, used during testing.
+
+ Args:
+ img (ndarray): Input image.
+ scale (int): Scale factor.
+
+ Returns:
+ ndarray: Result image.
+ """
+ img = img.copy()
+ if img.ndim in (2, 3):
+ h, w = img.shape[0], img.shape[1]
+ h_remainder, w_remainder = h % scale, w % scale
+ img = img[:h - h_remainder, :w - w_remainder, ...]
+ else:
+ raise ValueError(f'Wrong img ndim: {img.ndim}.')
+ return img
+
+
+def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
+ """Paired random crop. Support Numpy array and Tensor inputs.
+
+ It crops lists of lq and gt images with corresponding locations.
+
+ Args:
+ img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
+ should have the same shape. If the input is an ndarray, it will
+ be transformed to a list containing itself.
+ img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
+ should have the same shape. If the input is an ndarray, it will
+ be transformed to a list containing itself.
+ gt_patch_size (int): GT patch size.
+ scale (int): Scale factor.
+ gt_path (str): Path to ground-truth. Default: None.
+
+ Returns:
+ list[ndarray] | ndarray: GT images and LQ images. If returned results
+ only have one element, just return ndarray.
+ """
+
+ if not isinstance(img_gts, list):
+ img_gts = [img_gts]
+ if not isinstance(img_lqs, list):
+ img_lqs = [img_lqs]
+
+ # determine input type: Numpy array or Tensor
+ input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
+
+ if input_type == 'Tensor':
+ h_lq, w_lq = img_lqs[0].size()[-2:]
+ h_gt, w_gt = img_gts[0].size()[-2:]
+ else:
+ h_lq, w_lq = img_lqs[0].shape[0:2]
+ h_gt, w_gt = img_gts[0].shape[0:2]
+ lq_patch_size = gt_patch_size // scale
+
+ if h_gt != h_lq * scale or w_gt != w_lq * scale:
+        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
+                         f'multiplication of LQ ({h_lq}, {w_lq}).')
+ if h_lq < lq_patch_size or w_lq < lq_patch_size:
+ raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
+ f'({lq_patch_size}, {lq_patch_size}). '
+ f'Please remove {gt_path}.')
+
+ # randomly choose top and left coordinates for lq patch
+ top = random.randint(0, h_lq - lq_patch_size)
+ left = random.randint(0, w_lq - lq_patch_size)
+
+ # crop lq patch
+ if input_type == 'Tensor':
+ img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
+ else:
+ img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
+
+ # crop corresponding gt patch
+ top_gt, left_gt = int(top * scale), int(left * scale)
+ if input_type == 'Tensor':
+ img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
+ else:
+ img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
+ if len(img_gts) == 1:
+ img_gts = img_gts[0]
+ if len(img_lqs) == 1:
+ img_lqs = img_lqs[0]
+ return img_gts, img_lqs
+
+def triplet_random_crop(img_gts, img_lqs, img_segs, gt_patch_size, scale, gt_path=None):
+
+ if not isinstance(img_gts, list):
+ img_gts = [img_gts]
+ if not isinstance(img_lqs, list):
+ img_lqs = [img_lqs]
+ if not isinstance(img_segs, list):
+ img_segs = [img_segs]
+
+ # determine input type: Numpy array or Tensor
+ input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
+
+ if input_type == 'Tensor':
+ h_lq, w_lq = img_lqs[0].size()[-2:]
+ h_gt, w_gt = img_gts[0].size()[-2:]
+ h_seg, w_seg = img_segs[0].size()[-2:]
+ else:
+ h_lq, w_lq = img_lqs[0].shape[0:2]
+ h_gt, w_gt = img_gts[0].shape[0:2]
+ h_seg, w_seg = img_segs[0].shape[0:2]
+ lq_patch_size = gt_patch_size // scale
+
+ if h_gt != h_lq * scale or w_gt != w_lq * scale:
+        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
+                         f'multiplication of LQ ({h_lq}, {w_lq}).')
+ if h_lq < lq_patch_size or w_lq < lq_patch_size:
+ raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
+ f'({lq_patch_size}, {lq_patch_size}). '
+ f'Please remove {gt_path}.')
+
+ # randomly choose top and left coordinates for lq patch
+ top = random.randint(0, h_lq - lq_patch_size)
+ left = random.randint(0, w_lq - lq_patch_size)
+
+ # crop lq patch
+ if input_type == 'Tensor':
+ img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
+ else:
+ img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
+
+ # crop corresponding gt patch
+ top_gt, left_gt = int(top * scale), int(left * scale)
+ if input_type == 'Tensor':
+ img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
+ else:
+ img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
+
+ if input_type == 'Tensor':
+ img_segs = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_segs]
+ else:
+ img_segs = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_segs]
+
+ if len(img_gts) == 1:
+ img_gts = img_gts[0]
+ if len(img_lqs) == 1:
+ img_lqs = img_lqs[0]
+ if len(img_segs) == 1:
+ img_segs = img_segs[0]
+
+ return img_gts, img_lqs, img_segs
+
+
+def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
+ """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
+
+ We use vertical flip and transpose for rotation implementation.
+ All the images in the list use the same augmentation.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Images to be augmented. If the input
+ is an ndarray, it will be transformed to a list.
+ hflip (bool): Horizontal flip. Default: True.
+        rotation (bool): Rotation. Default: True.
+        flows (list[ndarray]): Flows to be augmented. If the input is an
+            ndarray, it will be transformed to a list.
+            Dimension is (h, w, 2). Default: None.
+ return_status (bool): Return the status of flip and rotation.
+ Default: False.
+
+ Returns:
+ list[ndarray] | ndarray: Augmented images and flows. If returned
+ results only have one element, just return ndarray.
+
+ """
+ hflip = hflip and random.random() < 0.5
+ vflip = rotation and random.random() < 0.5
+ rot90 = rotation and random.random() < 0.5
+
+ def _augment(img):
+ if hflip: # horizontal
+ cv2.flip(img, 1, img)
+ if vflip: # vertical
+ cv2.flip(img, 0, img)
+ if rot90:
+ img = img.transpose(1, 0, 2)
+ return img
+
+ def _augment_flow(flow):
+ if hflip: # horizontal
+ cv2.flip(flow, 1, flow)
+ flow[:, :, 0] *= -1
+ if vflip: # vertical
+ cv2.flip(flow, 0, flow)
+ flow[:, :, 1] *= -1
+ if rot90:
+ flow = flow.transpose(1, 0, 2)
+ flow = flow[:, :, [1, 0]]
+ return flow
+
+ if not isinstance(imgs, list):
+ imgs = [imgs]
+ imgs = [_augment(img) for img in imgs]
+ if len(imgs) == 1:
+ imgs = imgs[0]
+
+ if flows is not None:
+ if not isinstance(flows, list):
+ flows = [flows]
+ flows = [_augment_flow(flow) for flow in flows]
+ if len(flows) == 1:
+ flows = flows[0]
+ return imgs, flows
+ else:
+ if return_status:
+ return imgs, (hflip, vflip, rot90)
+ else:
+ return imgs
+
+
+def img_rotate(img, angle, center=None, scale=1.0):
+ """Rotate image.
+
+ Args:
+ img (ndarray): Image to be rotated.
+ angle (float): Rotation angle in degrees. Positive values mean
+ counter-clockwise rotation.
+ center (tuple[int]): Rotation center. If the center is None,
+ initialize it as the center of the image. Default: None.
+ scale (float): Isotropic scale factor. Default: 1.0.
+ """
+ (h, w) = img.shape[:2]
+
+ if center is None:
+ center = (w // 2, h // 2)
+
+ matrix = cv2.getRotationMatrix2D(center, angle, scale)
+ rotated_img = cv2.warpAffine(img, matrix, (w, h))
+ return rotated_img
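
The crop-then-augment pattern above is exercised by every training dataset in this patch. A small self-contained sketch on random numpy images (sizes are arbitrary placeholders):

```python
# Sketch: aligned random crop followed by the same random flip/rotation on GT and LQ.
import numpy as np
from basicsr.data.transforms import paired_random_crop, augment

scale, gt_size = 4, 128
img_gt = np.random.rand(256, 256, 3).astype(np.float32)   # HWC, [0, 1]
img_lq = np.random.rand(64, 64, 3).astype(np.float32)     # GT size / scale

# crop a 128x128 GT patch and the spatially aligned 32x32 LQ patch
img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale)

# apply the same random horizontal flip / rotation to both images
img_gt, img_lq = augment([img_gt, img_lq], hflip=True, rotation=True)
print(img_gt.shape, img_lq.shape)  # (128, 128, 3) (32, 32, 3)
```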
diff --git a/StableSR/basicsr/data/video_test_dataset.py b/StableSR/basicsr/data/video_test_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..929f7d97472a0eb810e33e694d5362a6749ab4b6
--- /dev/null
+++ b/StableSR/basicsr/data/video_test_dataset.py
@@ -0,0 +1,283 @@
+import glob
+import torch
+from os import path as osp
+from torch.utils import data as data
+
+from basicsr.data.data_util import duf_downsample, generate_frame_indices, read_img_seq
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class VideoTestDataset(data.Dataset):
+ """Video test dataset.
+
+ Supported datasets: Vid4, REDS4, REDSofficial.
+    More generally, it supports testing datasets with the following structure:
+
+ ::
+
+ dataroot
+ ├── subfolder1
+ ├── frame000
+ ├── frame001
+ ├── ...
+ ├── subfolder2
+ ├── frame000
+ ├── frame001
+ ├── ...
+ ├── ...
+
+ For testing datasets, there is no need to prepare LMDB files.
+
+ Args:
+ opt (dict): Config for train dataset. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ dataroot_lq (str): Data root path for lq.
+ io_backend (dict): IO backend type and other kwarg.
+ cache_data (bool): Whether to cache testing datasets.
+ name (str): Dataset name.
+ meta_info_file (str): The path to the file storing the list of test folders. If not provided, all the folders
+ in the dataroot will be used.
+ num_frame (int): Window size for input frames.
+ padding (str): Padding mode.
+ """
+
+ def __init__(self, opt):
+ super(VideoTestDataset, self).__init__()
+ self.opt = opt
+ self.cache_data = opt['cache_data']
+ self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
+ self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'
+
+ logger = get_root_logger()
+ logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
+ self.imgs_lq, self.imgs_gt = {}, {}
+ if 'meta_info_file' in opt:
+ with open(opt['meta_info_file'], 'r') as fin:
+ subfolders = [line.split(' ')[0] for line in fin]
+ subfolders_lq = [osp.join(self.lq_root, key) for key in subfolders]
+ subfolders_gt = [osp.join(self.gt_root, key) for key in subfolders]
+ else:
+ subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*')))
+ subfolders_gt = sorted(glob.glob(osp.join(self.gt_root, '*')))
+
+ if opt['name'].lower() in ['vid4', 'reds4', 'redsofficial']:
+ for subfolder_lq, subfolder_gt in zip(subfolders_lq, subfolders_gt):
+ # get frame list for lq and gt
+ subfolder_name = osp.basename(subfolder_lq)
+ img_paths_lq = sorted(list(scandir(subfolder_lq, full_path=True)))
+ img_paths_gt = sorted(list(scandir(subfolder_gt, full_path=True)))
+
+ max_idx = len(img_paths_lq)
+ assert max_idx == len(img_paths_gt), (f'Different number of images in lq ({max_idx})'
+ f' and gt folders ({len(img_paths_gt)})')
+
+ self.data_info['lq_path'].extend(img_paths_lq)
+ self.data_info['gt_path'].extend(img_paths_gt)
+ self.data_info['folder'].extend([subfolder_name] * max_idx)
+ for i in range(max_idx):
+ self.data_info['idx'].append(f'{i}/{max_idx}')
+ border_l = [0] * max_idx
+ for i in range(self.opt['num_frame'] // 2):
+ border_l[i] = 1
+ border_l[max_idx - i - 1] = 1
+ self.data_info['border'].extend(border_l)
+
+ # cache data or save the frame list
+ if self.cache_data:
+ logger.info(f'Cache {subfolder_name} for VideoTestDataset...')
+ self.imgs_lq[subfolder_name] = read_img_seq(img_paths_lq)
+ self.imgs_gt[subfolder_name] = read_img_seq(img_paths_gt)
+ else:
+ self.imgs_lq[subfolder_name] = img_paths_lq
+ self.imgs_gt[subfolder_name] = img_paths_gt
+ else:
+ raise ValueError(f'Non-supported video test dataset: {type(opt["name"])}')
+
+ def __getitem__(self, index):
+ folder = self.data_info['folder'][index]
+ idx, max_idx = self.data_info['idx'][index].split('/')
+ idx, max_idx = int(idx), int(max_idx)
+ border = self.data_info['border'][index]
+ lq_path = self.data_info['lq_path'][index]
+
+ select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])
+
+ if self.cache_data:
+ imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
+ img_gt = self.imgs_gt[folder][idx]
+ else:
+ img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
+ imgs_lq = read_img_seq(img_paths_lq)
+ img_gt = read_img_seq([self.imgs_gt[folder][idx]])
+ img_gt.squeeze_(0)
+
+ return {
+ 'lq': imgs_lq, # (t, c, h, w)
+ 'gt': img_gt, # (c, h, w)
+ 'folder': folder, # folder name
+ 'idx': self.data_info['idx'][index], # e.g., 0/99
+ 'border': border, # 1 for border, 0 for non-border
+ 'lq_path': lq_path # center frame
+ }
+
+ def __len__(self):
+ return len(self.data_info['gt_path'])
+
+
+@DATASET_REGISTRY.register()
+class VideoTestVimeo90KDataset(data.Dataset):
+ """Video test dataset for Vimeo90k-Test dataset.
+
+ It only keeps the center frame for testing.
+ For testing datasets, there is no need to prepare LMDB files.
+
+ Args:
+ opt (dict): Config for train dataset. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ dataroot_lq (str): Data root path for lq.
+ io_backend (dict): IO backend type and other kwarg.
+ cache_data (bool): Whether to cache testing datasets.
+ name (str): Dataset name.
+ meta_info_file (str): The path to the file storing the list of test folders. If not provided, all the folders
+ in the dataroot will be used.
+ num_frame (int): Window size for input frames.
+ padding (str): Padding mode.
+ """
+
+ def __init__(self, opt):
+ super(VideoTestVimeo90KDataset, self).__init__()
+ self.opt = opt
+ self.cache_data = opt['cache_data']
+ if self.cache_data:
+ raise NotImplementedError('cache_data in Vimeo90K-Test dataset is not implemented.')
+ self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
+ self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
+ neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]
+
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ assert self.io_backend_opt['type'] != 'lmdb', 'No need to use lmdb during validation/test.'
+
+ logger = get_root_logger()
+ logger.info(f'Generate data info for VideoTestDataset - {opt["name"]}')
+ with open(opt['meta_info_file'], 'r') as fin:
+ subfolders = [line.split(' ')[0] for line in fin]
+ for idx, subfolder in enumerate(subfolders):
+ gt_path = osp.join(self.gt_root, subfolder, 'im4.png')
+ self.data_info['gt_path'].append(gt_path)
+ lq_paths = [osp.join(self.lq_root, subfolder, f'im{i}.png') for i in neighbor_list]
+ self.data_info['lq_path'].append(lq_paths)
+ self.data_info['folder'].append('vimeo90k')
+ self.data_info['idx'].append(f'{idx}/{len(subfolders)}')
+ self.data_info['border'].append(0)
+
+ def __getitem__(self, index):
+ lq_path = self.data_info['lq_path'][index]
+ gt_path = self.data_info['gt_path'][index]
+ imgs_lq = read_img_seq(lq_path)
+ img_gt = read_img_seq([gt_path])
+ img_gt.squeeze_(0)
+
+ return {
+ 'lq': imgs_lq, # (t, c, h, w)
+ 'gt': img_gt, # (c, h, w)
+ 'folder': self.data_info['folder'][index], # folder name
+ 'idx': self.data_info['idx'][index], # e.g., 0/843
+ 'border': self.data_info['border'][index], # 0 for non-border
+ 'lq_path': lq_path[self.opt['num_frame'] // 2] # center frame
+ }
+
+ def __len__(self):
+ return len(self.data_info['gt_path'])
+
+
+@DATASET_REGISTRY.register()
+class VideoTestDUFDataset(VideoTestDataset):
+ """ Video test dataset for DUF dataset.
+
+ Args:
+ opt (dict): Config for train dataset. Most of keys are the same as VideoTestDataset.
+ It has the following extra keys:
+ use_duf_downsampling (bool): Whether to use duf downsampling to generate low-resolution frames.
+            scale (int): Scale factor, which will be added automatically.
+ """
+
+ def __getitem__(self, index):
+ folder = self.data_info['folder'][index]
+ idx, max_idx = self.data_info['idx'][index].split('/')
+ idx, max_idx = int(idx), int(max_idx)
+ border = self.data_info['border'][index]
+ lq_path = self.data_info['lq_path'][index]
+
+ select_idx = generate_frame_indices(idx, max_idx, self.opt['num_frame'], padding=self.opt['padding'])
+
+ if self.cache_data:
+ if self.opt['use_duf_downsampling']:
+ # read imgs_gt to generate low-resolution frames
+ imgs_lq = self.imgs_gt[folder].index_select(0, torch.LongTensor(select_idx))
+ imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
+ else:
+ imgs_lq = self.imgs_lq[folder].index_select(0, torch.LongTensor(select_idx))
+ img_gt = self.imgs_gt[folder][idx]
+ else:
+ if self.opt['use_duf_downsampling']:
+ img_paths_lq = [self.imgs_gt[folder][i] for i in select_idx]
+ # read imgs_gt to generate low-resolution frames
+ imgs_lq = read_img_seq(img_paths_lq, require_mod_crop=True, scale=self.opt['scale'])
+ imgs_lq = duf_downsample(imgs_lq, kernel_size=13, scale=self.opt['scale'])
+ else:
+ img_paths_lq = [self.imgs_lq[folder][i] for i in select_idx]
+ imgs_lq = read_img_seq(img_paths_lq)
+ img_gt = read_img_seq([self.imgs_gt[folder][idx]], require_mod_crop=True, scale=self.opt['scale'])
+ img_gt.squeeze_(0)
+
+ return {
+ 'lq': imgs_lq, # (t, c, h, w)
+ 'gt': img_gt, # (c, h, w)
+ 'folder': folder, # folder name
+ 'idx': self.data_info['idx'][index], # e.g., 0/99
+ 'border': border, # 1 for border, 0 for non-border
+ 'lq_path': lq_path # center frame
+ }
+
+
+@DATASET_REGISTRY.register()
+class VideoRecurrentTestDataset(VideoTestDataset):
+ """Video test dataset for recurrent architectures, which takes LR video
+    frames as input and outputs the corresponding HR video frames.
+
+ Args:
+ opt (dict): Same as VideoTestDataset. Unused opt:
+ padding (str): Padding mode.
+
+ """
+
+ def __init__(self, opt):
+ super(VideoRecurrentTestDataset, self).__init__(opt)
+ # Find unique folder strings
+ self.folders = sorted(list(set(self.data_info['folder'])))
+
+ def __getitem__(self, index):
+ folder = self.folders[index]
+
+ if self.cache_data:
+ imgs_lq = self.imgs_lq[folder]
+ imgs_gt = self.imgs_gt[folder]
+ else:
+ raise NotImplementedError('Without cache_data is not implemented.')
+
+ return {
+ 'lq': imgs_lq,
+ 'gt': imgs_gt,
+ 'folder': folder,
+ }
+
+ def __len__(self):
+ return len(self.folders)
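
A hedged config sketch for the recurrent test dataset above; the dataroot paths are placeholders, `num_frame: -1` and `padding` follow the convention of not being used by this class, and `cache_data` must be True since the non-cached path is not implemented.

```python
# Hypothetical config sketch for VideoRecurrentTestDataset; paths are placeholders.
from basicsr.data.video_test_dataset import VideoRecurrentTestDataset

opt = {
    'name': 'REDS4',                                 # must be vid4 / reds4 / redsofficial
    'dataroot_gt': 'datasets/REDS4/GT',              # placeholder path
    'dataroot_lq': 'datasets/REDS4/sharp_bicubic',   # placeholder path
    'io_backend': {'type': 'disk'},
    'cache_data': True,                              # whole clips are cached in memory
    'num_frame': -1,                                 # not used by the recurrent test dataset
    'padding': 'reflection',                         # not used either, kept for interface parity
}
dataset = VideoRecurrentTestDataset(opt)
clip = dataset[0]
print(clip['folder'], clip['lq'].shape, clip['gt'].shape)  # (t, c, h, w)
```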
diff --git a/StableSR/basicsr/data/vimeo90k_dataset.py b/StableSR/basicsr/data/vimeo90k_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5e33e1082667aeee61fecf2436fb287e82e0936
--- /dev/null
+++ b/StableSR/basicsr/data/vimeo90k_dataset.py
@@ -0,0 +1,199 @@
+import random
+import torch
+from pathlib import Path
+from torch.utils import data as data
+
+from basicsr.data.transforms import augment, paired_random_crop
+from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
+from basicsr.utils.registry import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class Vimeo90KDataset(data.Dataset):
+ """Vimeo90K dataset for training.
+
+ The keys are generated from a meta info txt file.
+ basicsr/data/meta_info/meta_info_Vimeo90K_train_GT.txt
+
+ Each line contains the following items, separated by a white space.
+
+ 1. clip name;
+ 2. frame number;
+ 3. image shape
+
+ Examples:
+
+ ::
+
+ 00001/0001 7 (256,448,3)
+ 00001/0002 7 (256,448,3)
+
+ - Key examples: "00001/0001"
+ - GT (gt): Ground-Truth;
+ - LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
+
+ The neighboring frame list for different num_frame:
+
+ ::
+
+ num_frame | frame list
+ 1 | 4
+ 3 | 3,4,5
+ 5 | 2,3,4,5,6
+ 7 | 1,2,3,4,5,6,7
+
+ Args:
+ opt (dict): Config for train dataset. It contains the following keys:
+ dataroot_gt (str): Data root path for gt.
+ dataroot_lq (str): Data root path for lq.
+ meta_info_file (str): Path for meta information file.
+ io_backend (dict): IO backend type and other kwarg.
+ num_frame (int): Window size for input frames.
+            gt_size (int): Cropped patch size for gt patches.
+ random_reverse (bool): Random reverse input frames.
+ use_hflip (bool): Use horizontal flips.
+ use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
+            scale (int): Scale factor, which will be added automatically.
+ """
+
+ def __init__(self, opt):
+ super(Vimeo90KDataset, self).__init__()
+ self.opt = opt
+ self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
+
+ with open(opt['meta_info_file'], 'r') as fin:
+ self.keys = [line.split(' ')[0] for line in fin]
+
+ # file client (io backend)
+ self.file_client = None
+ self.io_backend_opt = opt['io_backend']
+ self.is_lmdb = False
+ if self.io_backend_opt['type'] == 'lmdb':
+ self.is_lmdb = True
+ self.io_backend_opt['db_paths'] = [self.lq_root, self.gt_root]
+ self.io_backend_opt['client_keys'] = ['lq', 'gt']
+
+ # indices of input images
+ self.neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]
+
+ # temporal augmentation configs
+ self.random_reverse = opt['random_reverse']
+ logger = get_root_logger()
+ logger.info(f'Random reverse is {self.random_reverse}.')
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ # random reverse
+ if self.random_reverse and random.random() < 0.5:
+ self.neighbor_list.reverse()
+
+ scale = self.opt['scale']
+ gt_size = self.opt['gt_size']
+ key = self.keys[index]
+ clip, seq = key.split('/') # key example: 00001/0001
+
+ # get the GT frame (im4.png)
+ if self.is_lmdb:
+ img_gt_path = f'{key}/im4'
+ else:
+ img_gt_path = self.gt_root / clip / seq / 'im4.png'
+ img_bytes = self.file_client.get(img_gt_path, 'gt')
+ img_gt = imfrombytes(img_bytes, float32=True)
+
+ # get the neighboring LQ frames
+ img_lqs = []
+ for neighbor in self.neighbor_list:
+ if self.is_lmdb:
+ img_lq_path = f'{clip}/{seq}/im{neighbor}'
+ else:
+ img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
+ img_bytes = self.file_client.get(img_lq_path, 'lq')
+ img_lq = imfrombytes(img_bytes, float32=True)
+ img_lqs.append(img_lq)
+
+ # randomly crop
+ img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)
+
+ # augmentation - flip, rotate
+ img_lqs.append(img_gt)
+ img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
+
+ img_results = img2tensor(img_results)
+ img_lqs = torch.stack(img_results[0:-1], dim=0)
+ img_gt = img_results[-1]
+
+ # img_lqs: (t, c, h, w)
+ # img_gt: (c, h, w)
+ # key: str
+ return {'lq': img_lqs, 'gt': img_gt, 'key': key}
+
+ def __len__(self):
+ return len(self.keys)
+
+
+@DATASET_REGISTRY.register()
+class Vimeo90KRecurrentDataset(Vimeo90KDataset):
+
+ def __init__(self, opt):
+ super(Vimeo90KRecurrentDataset, self).__init__(opt)
+
+ self.flip_sequence = opt['flip_sequence']
+ self.neighbor_list = [1, 2, 3, 4, 5, 6, 7]
+
+ def __getitem__(self, index):
+ if self.file_client is None:
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+ # random reverse
+ if self.random_reverse and random.random() < 0.5:
+ self.neighbor_list.reverse()
+
+ scale = self.opt['scale']
+ gt_size = self.opt['gt_size']
+ key = self.keys[index]
+ clip, seq = key.split('/') # key example: 00001/0001
+
+ # get the neighboring LQ and GT frames
+ img_lqs = []
+ img_gts = []
+ for neighbor in self.neighbor_list:
+ if self.is_lmdb:
+ img_lq_path = f'{clip}/{seq}/im{neighbor}'
+ img_gt_path = f'{clip}/{seq}/im{neighbor}'
+ else:
+ img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
+ img_gt_path = self.gt_root / clip / seq / f'im{neighbor}.png'
+ # LQ
+ img_bytes = self.file_client.get(img_lq_path, 'lq')
+ img_lq = imfrombytes(img_bytes, float32=True)
+ # GT
+ img_bytes = self.file_client.get(img_gt_path, 'gt')
+ img_gt = imfrombytes(img_bytes, float32=True)
+
+ img_lqs.append(img_lq)
+ img_gts.append(img_gt)
+
+ # randomly crop
+ img_gts, img_lqs = paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)
+
+ # augmentation - flip, rotate
+ img_lqs.extend(img_gts)
+ img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
+
+ img_results = img2tensor(img_results)
+ img_lqs = torch.stack(img_results[:7], dim=0)
+ img_gts = torch.stack(img_results[7:], dim=0)
+
+ if self.flip_sequence: # flip the sequence: 7 frames to 14 frames
+ img_lqs = torch.cat([img_lqs, img_lqs.flip(0)], dim=0)
+ img_gts = torch.cat([img_gts, img_gts.flip(0)], dim=0)
+
+ # img_lqs: (t, c, h, w)
+ # img_gt: (c, h, w)
+ # key: str
+ return {'lq': img_lqs, 'gt': img_gts, 'key': key}
+
+ def __len__(self):
+ return len(self.keys)
diff --git a/StableSR/basicsr/losses/__init__.py b/StableSR/basicsr/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..70a172aeed5b388ae102466eb1f02d40ba30e9b4
--- /dev/null
+++ b/StableSR/basicsr/losses/__init__.py
@@ -0,0 +1,31 @@
+import importlib
+from copy import deepcopy
+from os import path as osp
+
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.registry import LOSS_REGISTRY
+from .gan_loss import g_path_regularize, gradient_penalty_loss, r1_penalty
+
+__all__ = ['build_loss', 'gradient_penalty_loss', 'r1_penalty', 'g_path_regularize']
+
+# automatically scan and import loss modules for registry
+# scan all the files under the 'losses' folder and collect files ending with '_loss.py'
+loss_folder = osp.dirname(osp.abspath(__file__))
+loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')]
+# import all the loss modules
+_model_modules = [importlib.import_module(f'basicsr.losses.{file_name}') for file_name in loss_filenames]
+
+
+def build_loss(opt):
+ """Build loss from options.
+
+ Args:
+ opt (dict): Configuration. It must contain:
+ type (str): Model type.
+ """
+ opt = deepcopy(opt)
+ loss_type = opt.pop('type')
+ loss = LOSS_REGISTRY.get(loss_type)(**opt)
+ logger = get_root_logger()
+ logger.info(f'Loss [{loss.__class__.__name__}] is created.')
+ return loss
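
Since `build_loss` pops `type` from the options and resolves it through `LOSS_REGISTRY` (populated by the automatic `*_loss.py` scan above), a minimal sketch of building and calling a registered loss looks like this:

```python
# Minimal sketch: build an L1 loss from an options dict via the registry.
import torch
from basicsr.losses import build_loss

loss_opt = {'type': 'L1Loss', 'loss_weight': 1.0, 'reduction': 'mean'}
cri_pix = build_loss(loss_opt)      # returns an L1Loss instance

pred = torch.rand(1, 3, 8, 8)
target = torch.rand(1, 3, 8, 8)
print(cri_pix(pred, target))        # scalar loss tensor
```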
diff --git a/StableSR/basicsr/losses/basic_loss.py b/StableSR/basicsr/losses/basic_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2e965526a9b0e2686575bf93f0173cc2664d9bb
--- /dev/null
+++ b/StableSR/basicsr/losses/basic_loss.py
@@ -0,0 +1,253 @@
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.archs.vgg_arch import VGGFeatureExtractor
+from basicsr.utils.registry import LOSS_REGISTRY
+from .loss_util import weighted_loss
+
+_reduction_modes = ['none', 'mean', 'sum']
+
+
+@weighted_loss
+def l1_loss(pred, target):
+ return F.l1_loss(pred, target, reduction='none')
+
+
+@weighted_loss
+def mse_loss(pred, target):
+ return F.mse_loss(pred, target, reduction='none')
+
+
+@weighted_loss
+def charbonnier_loss(pred, target, eps=1e-12):
+ return torch.sqrt((pred - target)**2 + eps)
+
+
+@LOSS_REGISTRY.register()
+class L1Loss(nn.Module):
+ """L1 (mean absolute error, MAE) loss.
+
+ Args:
+ loss_weight (float): Loss weight for L1 loss. Default: 1.0.
+ reduction (str): Specifies the reduction to apply to the output.
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+ """
+
+ def __init__(self, loss_weight=1.0, reduction='mean'):
+ super(L1Loss, self).__init__()
+ if reduction not in ['none', 'mean', 'sum']:
+ raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
+
+ self.loss_weight = loss_weight
+ self.reduction = reduction
+
+ def forward(self, pred, target, weight=None, **kwargs):
+ """
+ Args:
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
+ """
+ return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class MSELoss(nn.Module):
+ """MSE (L2) loss.
+
+ Args:
+ loss_weight (float): Loss weight for MSE loss. Default: 1.0.
+ reduction (str): Specifies the reduction to apply to the output.
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+ """
+
+ def __init__(self, loss_weight=1.0, reduction='mean'):
+ super(MSELoss, self).__init__()
+ if reduction not in ['none', 'mean', 'sum']:
+ raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
+
+ self.loss_weight = loss_weight
+ self.reduction = reduction
+
+ def forward(self, pred, target, weight=None, **kwargs):
+ """
+ Args:
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
+ """
+ return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class CharbonnierLoss(nn.Module):
+ """Charbonnier loss (one variant of Robust L1Loss, a differentiable
+ variant of L1Loss).
+
+ Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
+ Super-Resolution".
+
+ Args:
+ loss_weight (float): Loss weight for L1 loss. Default: 1.0.
+ reduction (str): Specifies the reduction to apply to the output.
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+ eps (float): A value used to control the curvature near zero. Default: 1e-12.
+ """
+
+ def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
+ super(CharbonnierLoss, self).__init__()
+ if reduction not in ['none', 'mean', 'sum']:
+ raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
+
+ self.loss_weight = loss_weight
+ self.reduction = reduction
+ self.eps = eps
+
+ def forward(self, pred, target, weight=None, **kwargs):
+ """
+ Args:
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
+ """
+ return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class WeightedTVLoss(L1Loss):
+ """Weighted TV loss.
+
+ Args:
+ loss_weight (float): Loss weight. Default: 1.0.
+ """
+
+ def __init__(self, loss_weight=1.0, reduction='mean'):
+ if reduction not in ['mean', 'sum']:
+ raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: mean | sum')
+ super(WeightedTVLoss, self).__init__(loss_weight=loss_weight, reduction=reduction)
+
+ def forward(self, pred, weight=None):
+ if weight is None:
+ y_weight = None
+ x_weight = None
+ else:
+ y_weight = weight[:, :, :-1, :]
+ x_weight = weight[:, :, :, :-1]
+
+ y_diff = super().forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
+ x_diff = super().forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)
+
+ loss = x_diff + y_diff
+
+ return loss
+
+
+@LOSS_REGISTRY.register()
+class PerceptualLoss(nn.Module):
+ """Perceptual loss with commonly used style loss.
+
+ Args:
+ layer_weights (dict): The weight for each layer of vgg feature.
+ Here is an example: {'conv5_4': 1.}, which means the conv5_4
+ feature layer (before relu5_4) will be extracted with weight
+ 1.0 in calculating losses.
+ vgg_type (str): The type of vgg network used as feature extractor.
+ Default: 'vgg19'.
+ use_input_norm (bool): If True, normalize the input image in vgg.
+ Default: True.
+        range_norm (bool): If True, normalize images from the range [-1, 1]
+            to [0, 1]. Default: False.
+        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
+            loss will be calculated and the loss will be multiplied by the
+            weight. Default: 1.0.
+        style_weight (float): If `style_weight > 0`, the style loss will be
+            calculated and the loss will be multiplied by the weight.
+            Default: 0.
+ criterion (str): Criterion used for perceptual loss. Default: 'l1'.
+ """
+
+ def __init__(self,
+ layer_weights,
+ vgg_type='vgg19',
+ use_input_norm=True,
+ range_norm=False,
+ perceptual_weight=1.0,
+ style_weight=0.,
+ criterion='l1'):
+ super(PerceptualLoss, self).__init__()
+ self.perceptual_weight = perceptual_weight
+ self.style_weight = style_weight
+ self.layer_weights = layer_weights
+ self.vgg = VGGFeatureExtractor(
+ layer_name_list=list(layer_weights.keys()),
+ vgg_type=vgg_type,
+ use_input_norm=use_input_norm,
+ range_norm=range_norm)
+
+ self.criterion_type = criterion
+ if self.criterion_type == 'l1':
+ self.criterion = torch.nn.L1Loss()
+ elif self.criterion_type == 'l2':
+            self.criterion = torch.nn.MSELoss()  # 'l2' criterion maps to MSE
+ elif self.criterion_type == 'fro':
+ self.criterion = None
+ else:
+ raise NotImplementedError(f'{criterion} criterion has not been supported.')
+
+ def forward(self, x, gt):
+ """Forward function.
+
+ Args:
+ x (Tensor): Input tensor with shape (n, c, h, w).
+ gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
+
+ Returns:
+ Tensor: Forward results.
+ """
+ # extract vgg features
+ x_features = self.vgg(x)
+ gt_features = self.vgg(gt.detach())
+
+ # calculate perceptual loss
+ if self.perceptual_weight > 0:
+ percep_loss = 0
+ for k in x_features.keys():
+ if self.criterion_type == 'fro':
+ percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
+ else:
+ percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
+ percep_loss *= self.perceptual_weight
+ else:
+ percep_loss = None
+
+ # calculate style loss
+ if self.style_weight > 0:
+ style_loss = 0
+ for k in x_features.keys():
+ if self.criterion_type == 'fro':
+ style_loss += torch.norm(
+ self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
+ else:
+ style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
+ gt_features[k])) * self.layer_weights[k]
+ style_loss *= self.style_weight
+ else:
+ style_loss = None
+
+ return percep_loss, style_loss
+
+ def _gram_mat(self, x):
+ """Calculate Gram matrix.
+
+ Args:
+ x (torch.Tensor): Tensor with shape of (n, c, h, w).
+
+ Returns:
+ torch.Tensor: Gram matrix.
+ """
+ n, c, h, w = x.size()
+ features = x.view(n, c, w * h)
+ features_t = features.transpose(1, 2)
+ gram = features.bmm(features_t) / (c * h * w)
+ return gram
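
A short usage sketch for `PerceptualLoss` on random tensors (the layer weight below is an arbitrary choice; constructing the loss downloads the pretrained VGG19 weights on first use):

```python
# Sketch: perceptual loss with style loss disabled; inputs are random [0, 1] tensors.
import torch
from basicsr.losses.basic_loss import PerceptualLoss

criterion = PerceptualLoss(
    layer_weights={'conv5_4': 1.0},   # weight for the conv5_4 feature map
    vgg_type='vgg19',
    perceptual_weight=1.0,
    style_weight=0.0)                 # style loss disabled -> returned as None

sr = torch.rand(1, 3, 64, 64)         # network output in [0, 1]
gt = torch.rand(1, 3, 64, 64)
l_percep, l_style = criterion(sr, gt)
print(l_percep, l_style)              # tensor(...), None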
diff --git a/StableSR/basicsr/losses/gan_loss.py b/StableSR/basicsr/losses/gan_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..870baa2227b79eab29a3141a216b4b614e2bcdf3
--- /dev/null
+++ b/StableSR/basicsr/losses/gan_loss.py
@@ -0,0 +1,207 @@
+import math
+import torch
+from torch import autograd as autograd
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import LOSS_REGISTRY
+
+
+@LOSS_REGISTRY.register()
+class GANLoss(nn.Module):
+ """Define GAN loss.
+
+ Args:
+ gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
+ real_label_val (float): The value for real label. Default: 1.0.
+ fake_label_val (float): The value for fake label. Default: 0.0.
+ loss_weight (float): Loss weight. Default: 1.0.
+            Note that loss_weight is only used for generators; it is always
+            1.0 for discriminators.
+ """
+
+ def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
+ super(GANLoss, self).__init__()
+ self.gan_type = gan_type
+ self.loss_weight = loss_weight
+ self.real_label_val = real_label_val
+ self.fake_label_val = fake_label_val
+
+ if self.gan_type == 'vanilla':
+ self.loss = nn.BCEWithLogitsLoss()
+ elif self.gan_type == 'lsgan':
+ self.loss = nn.MSELoss()
+ elif self.gan_type == 'wgan':
+ self.loss = self._wgan_loss
+ elif self.gan_type == 'wgan_softplus':
+ self.loss = self._wgan_softplus_loss
+ elif self.gan_type == 'hinge':
+ self.loss = nn.ReLU()
+ else:
+ raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
+
+ def _wgan_loss(self, input, target):
+ """wgan loss.
+
+ Args:
+ input (Tensor): Input tensor.
+ target (bool): Target label.
+
+ Returns:
+ Tensor: wgan loss.
+ """
+ return -input.mean() if target else input.mean()
+
+ def _wgan_softplus_loss(self, input, target):
+ """wgan loss with soft plus. softplus is a smooth approximation to the
+ ReLU function.
+
+ In StyleGAN2, it is called:
+ Logistic loss for discriminator;
+ Non-saturating loss for generator.
+
+ Args:
+ input (Tensor): Input tensor.
+ target (bool): Target label.
+
+ Returns:
+ Tensor: wgan loss.
+ """
+ return F.softplus(-input).mean() if target else F.softplus(input).mean()
+
+ def get_target_label(self, input, target_is_real):
+ """Get target label.
+
+ Args:
+ input (Tensor): Input tensor.
+ target_is_real (bool): Whether the target is real or fake.
+
+ Returns:
+ (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
+ return Tensor.
+ """
+
+ if self.gan_type in ['wgan', 'wgan_softplus']:
+ return target_is_real
+ target_val = (self.real_label_val if target_is_real else self.fake_label_val)
+ return input.new_ones(input.size()) * target_val
+
+ def forward(self, input, target_is_real, is_disc=False):
+ """
+ Args:
+ input (Tensor): The input for the loss module, i.e., the network
+ prediction.
+            target_is_real (bool): Whether the target is real or fake.
+ is_disc (bool): Whether the loss for discriminators or not.
+ Default: False.
+
+ Returns:
+ Tensor: GAN loss value.
+ """
+ target_label = self.get_target_label(input, target_is_real)
+ if self.gan_type == 'hinge':
+ if is_disc: # for discriminators in hinge-gan
+ input = -input if target_is_real else input
+ loss = self.loss(1 + input).mean()
+ else: # for generators in hinge-gan
+ loss = -input.mean()
+ else: # other gan types
+ loss = self.loss(input, target_label)
+
+ # loss_weight is always 1.0 for discriminators
+ return loss if is_disc else loss * self.loss_weight
+
+
+@LOSS_REGISTRY.register()
+class MultiScaleGANLoss(GANLoss):
+ """
+ MultiScaleGANLoss accepts a list of predictions
+ """
+
+ def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
+ super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight)
+
+ def forward(self, input, target_is_real, is_disc=False):
+ """
+ The input is a list of tensors, or a list of (a list of tensors)
+ """
+ if isinstance(input, list):
+ loss = 0
+ for pred_i in input:
+ if isinstance(pred_i, list):
+ # Only compute GAN loss for the last layer
+ # in case of multiscale feature matching
+ pred_i = pred_i[-1]
+                # Safe operation: calling .mean() on a 0-dim tensor is a no-op
+ loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean()
+ loss += loss_tensor
+ return loss / len(input)
+ else:
+ return super().forward(input, target_is_real, is_disc)
+
+
+def r1_penalty(real_pred, real_img):
+ """R1 regularization for discriminator. The core idea is to
+ penalize the gradient on real data alone: when the
+ generator distribution produces the true data distribution
+ and the discriminator is equal to 0 on the data manifold, the
+ gradient penalty ensures that the discriminator cannot create
+ a non-zero gradient orthogonal to the data manifold without
+ suffering a loss in the GAN game.
+
+ Reference: Eq. 9 in Which training methods for GANs do actually converge.
+ """
+ grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
+ grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
+ return grad_penalty
+
+
+def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
+ noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
+ grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
+ path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
+
+ path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
+
+ path_penalty = (path_lengths - path_mean).pow(2).mean()
+
+ return path_penalty, path_lengths.detach().mean(), path_mean.detach()
+
+
+def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
+ """Calculate gradient penalty for wgan-gp.
+
+ Args:
+ discriminator (nn.Module): Network for the discriminator.
+ real_data (Tensor): Real input data.
+ fake_data (Tensor): Fake input data.
+ weight (Tensor): Weight tensor. Default: None.
+
+ Returns:
+ Tensor: A tensor for gradient penalty.
+ """
+
+ batch_size = real_data.size(0)
+ alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
+
+ # interpolate between real_data and fake_data
+ interpolates = alpha * real_data + (1. - alpha) * fake_data
+ interpolates = autograd.Variable(interpolates, requires_grad=True)
+
+ disc_interpolates = discriminator(interpolates)
+ gradients = autograd.grad(
+ outputs=disc_interpolates,
+ inputs=interpolates,
+ grad_outputs=torch.ones_like(disc_interpolates),
+ create_graph=True,
+ retain_graph=True,
+ only_inputs=True)[0]
+
+ if weight is not None:
+ gradients = gradients * weight
+
+ gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
+ if weight is not None:
+ gradients_penalty /= torch.mean(weight)
+
+ return gradients_penalty
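
For reference, the usual generator/discriminator calls into `GANLoss`; the logits below are random stand-ins for real discriminator outputs.

```python
# Sketch: how GANLoss is typically called from the generator and discriminator steps.
import torch
from basicsr.losses.gan_loss import GANLoss

cri_gan = GANLoss(gan_type='vanilla', loss_weight=0.1)

fake_pred = torch.randn(4, 1)   # discriminator logits for generated images
real_pred = torch.randn(4, 1)   # discriminator logits for real images

# generator step: fool the discriminator (loss_weight is applied)
l_g_gan = cri_gan(fake_pred, target_is_real=True, is_disc=False)

# discriminator step: loss_weight is ignored when is_disc=True
l_d_real = cri_gan(real_pred, target_is_real=True, is_disc=True)
l_d_fake = cri_gan(fake_pred, target_is_real=False, is_disc=True)
print(l_g_gan, l_d_real, l_d_fake)
```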
diff --git a/StableSR/basicsr/losses/loss_util.py b/StableSR/basicsr/losses/loss_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd293ff9e6a22814e5aeff6ae11fb54d2e4bafff
--- /dev/null
+++ b/StableSR/basicsr/losses/loss_util.py
@@ -0,0 +1,145 @@
+import functools
+import torch
+from torch.nn import functional as F
+
+
+def reduce_loss(loss, reduction):
+ """Reduce loss as specified.
+
+ Args:
+ loss (Tensor): Elementwise loss tensor.
+ reduction (str): Options are 'none', 'mean' and 'sum'.
+
+ Returns:
+ Tensor: Reduced loss tensor.
+ """
+ reduction_enum = F._Reduction.get_enum(reduction)
+ # none: 0, elementwise_mean:1, sum: 2
+ if reduction_enum == 0:
+ return loss
+ elif reduction_enum == 1:
+ return loss.mean()
+ else:
+ return loss.sum()
+
+
+def weight_reduce_loss(loss, weight=None, reduction='mean'):
+ """Apply element-wise weight and reduce loss.
+
+ Args:
+ loss (Tensor): Element-wise loss.
+ weight (Tensor): Element-wise weights. Default: None.
+ reduction (str): Same as built-in losses of PyTorch. Options are
+ 'none', 'mean' and 'sum'. Default: 'mean'.
+
+ Returns:
+ Tensor: Loss values.
+ """
+ # if weight is specified, apply element-wise weight
+ if weight is not None:
+ assert weight.dim() == loss.dim()
+ assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
+ loss = loss * weight
+
+ # if weight is not specified or reduction is sum, just reduce the loss
+ if weight is None or reduction == 'sum':
+ loss = reduce_loss(loss, reduction)
+ # if reduction is mean, then compute mean over weight region
+ elif reduction == 'mean':
+ if weight.size(1) > 1:
+ weight = weight.sum()
+ else:
+ weight = weight.sum() * loss.size(1)
+ loss = loss.sum() / weight
+
+ return loss
+
+
+def weighted_loss(loss_func):
+ """Create a weighted version of a given loss function.
+
+ To use this decorator, the loss function must have the signature like
+ `loss_func(pred, target, **kwargs)`. The function only needs to compute
+ element-wise loss without any reduction. This decorator will add weight
+ and reduction arguments to the function. The decorated function will have
+ the signature like `loss_func(pred, target, weight=None, reduction='mean',
+ **kwargs)`.
+
+ :Example:
+
+ >>> import torch
+ >>> @weighted_loss
+ >>> def l1_loss(pred, target):
+ >>> return (pred - target).abs()
+
+ >>> pred = torch.Tensor([0, 2, 3])
+ >>> target = torch.Tensor([1, 1, 1])
+ >>> weight = torch.Tensor([1, 0, 1])
+
+ >>> l1_loss(pred, target)
+ tensor(1.3333)
+ >>> l1_loss(pred, target, weight)
+ tensor(1.5000)
+ >>> l1_loss(pred, target, reduction='none')
+ tensor([1., 1., 2.])
+ >>> l1_loss(pred, target, weight, reduction='sum')
+ tensor(3.)
+ """
+
+ @functools.wraps(loss_func)
+ def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
+ # get element-wise loss
+ loss = loss_func(pred, target, **kwargs)
+ loss = weight_reduce_loss(loss, weight, reduction)
+ return loss
+
+ return wrapper
+
+
+def get_local_weights(residual, ksize):
+ """Get local weights for generating the artifact map of LDL.
+
+ It is only called by the `get_refined_artifact_map` function.
+
+ Args:
+ residual (Tensor): Residual between predicted and ground truth images.
+ ksize (Int): size of the local window.
+
+ Returns:
+ Tensor: weight for each pixel to be discriminated as an artifact pixel
+ """
+
+ pad = (ksize - 1) // 2
+ residual_pad = F.pad(residual, pad=[pad, pad, pad, pad], mode='reflect')
+
+ unfolded_residual = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1)
+ pixel_level_weight = torch.var(unfolded_residual, dim=(-1, -2), unbiased=True, keepdim=True).squeeze(-1).squeeze(-1)
+
+ return pixel_level_weight
+
+
+def get_refined_artifact_map(img_gt, img_output, img_ema, ksize):
+ """Calculate the artifact map of LDL
+ (Details or Artifacts: A Locally Discriminative Learning Approach to Realistic Image Super-Resolution. In CVPR 2022)
+
+ Args:
+ img_gt (Tensor): ground truth images.
+ img_output (Tensor): output images given by the optimizing model.
+ img_ema (Tensor): output images given by the ema model.
+ ksize (Int): size of the local window.
+
+ Returns:
+ overall_weight: weight for each pixel to be discriminated as an artifact pixel
+ (calculated based on both local and global observations).
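+
+    :Example: (illustrative shape check only; the inputs are random tensors)
+
+    >>> img_gt = torch.rand(1, 3, 64, 64)
+    >>> img_output = torch.rand(1, 3, 64, 64)
+    >>> img_ema = torch.rand(1, 3, 64, 64)
+    >>> get_refined_artifact_map(img_gt, img_output, img_ema, ksize=7).shape
+    torch.Size([1, 1, 64, 64])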
+ """
+
+ residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True)
+ residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True)
+
+ patch_level_weight = torch.var(residual_sr.clone(), dim=(-1, -2, -3), keepdim=True)**(1 / 5)
+ pixel_level_weight = get_local_weights(residual_sr.clone(), ksize)
+ overall_weight = patch_level_weight * pixel_level_weight
+
+ overall_weight[residual_sr < residual_ema] = 0
+
+ return overall_weight
diff --git a/StableSR/basicsr/metrics/README.md b/StableSR/basicsr/metrics/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..98d00308ab79e92a2393f9759190de8122a8e79d
--- /dev/null
+++ b/StableSR/basicsr/metrics/README.md
@@ -0,0 +1,48 @@
+# Metrics
+
+[English](README.md) **|** [简体中文](README_CN.md)
+
+- [Conventions](#conventions)
+- [PSNR and SSIM](#psnr-and-ssim)
+
+## Conventions
+
+Since different input types can lead to different results, we adopt the following conventions for inputs:
+
+- Numpy type (usually the result of cv2)
+  - UINT8: BGR, [0, 255], (h, w, c)
+  - float: BGR, [0, 1], (h, w, c). Usually an intermediate result
+- Tensor type
+  - float: RGB, [0, 1], (n, c, h, w)
+
+Other conventions:
+
+- Results with the `_pt` suffix are PyTorch results
+- The PyTorch versions support batch computation
+- Color conversion is done in float32; metric calculation is done in float64
+
+## PSNR and SSIM
+
+PSNR and SSIM generally follow the same trend: a higher PSNR usually comes with a higher SSIM.
+The various PSNR implementations are consistent with one another. SSIM has many different implementations; ours is kept consistent with the original MATLAB version (see the [evaluation code](https://competitions.codalab.org/my/datasets/download/ebe960d8-0ec8-4846-a1a2-7c4a586a7378) of the [NTIRE17 challenge](https://competitions.codalab.org/competitions/16306#participate)).
+
+The comparisons of the different implementations are listed below.
+Summary: the PyTorch implementation matches the MATLAB one, with slight differences when running on GPU.
+
+- PSNR comparison
+
+|Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU |
+|:---| :---: | :---: | :---: | :---: | :---: |
+|baboon| RGB | 20.419710 | 20.419710 | 20.419710 |20.419710 |
+|baboon| Y | - |22.441898 | 22.441899 | 22.444916|
+|comic | RGB | 20.239912 | 20.239912 | 20.239912 | 20.239912 |
+|comic | Y | - | 21.720398 | 21.720398 | 21.721663|
+
+- SSIM comparison
+
+|Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU |
+|:---| :---: | :---: | :---: | :---: | :---: |
+|baboon| RGB | 0.391853 | 0.391853 | 0.391853|0.391853 |
+|baboon| Y | - |0.453097| 0.453097 | 0.453171|
+|comic | RGB | 0.567738 | 0.567738 | 0.567738 | 0.567738|
+|comic | Y | - | 0.585511 | 0.585511 | 0.585522 |
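+
+## Example usage
+
+A minimal sketch of calling the metrics under the conventions above (the file paths and the 4-pixel crop are illustrative, not prescribed):
+
+```python
+import cv2
+
+from basicsr.metrics import calculate_psnr, calculate_ssim
+from basicsr.metrics.psnr_ssim import calculate_psnr_pt, calculate_ssim_pt
+from basicsr.utils import img2tensor
+
+# Numpy inputs: BGR, [0, 255], (h, w, c)
+img = cv2.imread('restored.png', cv2.IMREAD_UNCHANGED)
+img2 = cv2.imread('gt.png', cv2.IMREAD_UNCHANGED)
+psnr = calculate_psnr(img, img2, crop_border=4, test_y_channel=True)
+ssim = calculate_ssim(img, img2, crop_border=4, test_y_channel=True)
+
+# Tensor inputs: RGB, [0, 1], (n, c, h, w); the `_pt` versions support batches
+img_pt = img2tensor(img / 255., bgr2rgb=True, float32=True).unsqueeze_(0)
+img2_pt = img2tensor(img2 / 255., bgr2rgb=True, float32=True).unsqueeze_(0)
+psnr_pt = calculate_psnr_pt(img_pt, img2_pt, crop_border=4, test_y_channel=True)
+ssim_pt = calculate_ssim_pt(img_pt, img2_pt, crop_border=4, test_y_channel=True)
+```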
diff --git a/StableSR/basicsr/metrics/README_CN.md b/StableSR/basicsr/metrics/README_CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..98d00308ab79e92a2393f9759190de8122a8e79d
--- /dev/null
+++ b/StableSR/basicsr/metrics/README_CN.md
@@ -0,0 +1,48 @@
+# Metrics
+
+[English](README.md) **|** [简体中文](README_CN.md)
+
+- [约定](#约定)
+- [PSNR 和 SSIM](#psnr-和-ssim)
+
+## 约定
+
+因为不同的输入类型会导致结果的不同,因此我们对输入做如下约定:
+
+- Numpy 类型 (一般是 cv2 的结果)
+ - UINT8: BGR, [0, 255], (h, w, c)
+ - float: BGR, [0, 1], (h, w, c). 一般作为中间结果
+- Tensor 类型
+ - float: RGB, [0, 1], (n, c, h, w)
+
+其他约定:
+
+- 以 `_pt` 结尾的是 PyTorch 结果
+- PyTorch version 支持 batch 计算
+- 颜色转换在 float32 上做;metric计算在 float64 上做
+
+## PSNR 和 SSIM
+
+PSNR 和 SSIM 的结果趋势是一致的,即一般 PSNR 高,则 SSIM 也高。
+在实现上, PSNR 的各种实现都很一致。SSIM 有各种各样的实现,我们这里和 MATLAB 最原始版本保持 (参考 [NTIRE17比赛](https://competitions.codalab.org/competitions/16306#participate) 的 [evaluation代码](https://competitions.codalab.org/my/datasets/download/ebe960d8-0ec8-4846-a1a2-7c4a586a7378))
+
+下面列了各个实现的结果比对.
+总结:PyTorch 实现和 MATLAB 实现基本一致,在 GPU 运行上会有稍许差异
+
+- PSNR 比对
+
+|Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU |
+|:---| :---: | :---: | :---: | :---: | :---: |
+|baboon| RGB | 20.419710 | 20.419710 | 20.419710 |20.419710 |
+|baboon| Y | - |22.441898 | 22.441899 | 22.444916|
+|comic | RGB | 20.239912 | 20.239912 | 20.239912 | 20.239912 |
+|comic | Y | - | 21.720398 | 21.720398 | 21.721663|
+
+- SSIM 比对
+
+|Image | Color Space | MATLAB | Numpy | PyTorch CPU | PyTorch GPU |
+|:---| :---: | :---: | :---: | :---: | :---: |
+|baboon| RGB | 0.391853 | 0.391853 | 0.391853|0.391853 |
+|baboon| Y | - |0.453097| 0.453097 | 0.453171|
+|comic | RGB | 0.567738 | 0.567738 | 0.567738 | 0.567738|
+|comic | Y | - | 0.585511 | 0.585511 | 0.585522 |
diff --git a/StableSR/basicsr/metrics/__init__.py b/StableSR/basicsr/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..330f3c863f66a98d41942c6995837283265d94ef
--- /dev/null
+++ b/StableSR/basicsr/metrics/__init__.py
@@ -0,0 +1,20 @@
+from copy import deepcopy
+
+from basicsr.utils.registry import METRIC_REGISTRY
+from .niqe import calculate_niqe
+from .psnr_ssim import calculate_psnr, calculate_ssim, calculate_ssim_pt, calculate_psnr_pt
+
+__all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe']
+
+
+def calculate_metric(data, opt):
+ """Calculate metric from data and options.
+
+ Args:
+ opt (dict): Configuration. It must contain:
+            type (str): Metric type.
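+
+    :Example: (illustrative usage with the registered PSNR metric; identical inputs give ``inf``)
+
+    >>> import numpy as np
+    >>> data = dict(img=np.zeros((16, 16, 3)), img2=np.zeros((16, 16, 3)))
+    >>> calculate_metric(data, dict(type='calculate_psnr', crop_border=0))
+    inf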
+ """
+ opt = deepcopy(opt)
+ metric_type = opt.pop('type')
+ metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
+ return metric
diff --git a/StableSR/basicsr/metrics/fid.py b/StableSR/basicsr/metrics/fid.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b0ba6df1de96d93a60c1cfd3dc1fcf4d3d31533
--- /dev/null
+++ b/StableSR/basicsr/metrics/fid.py
@@ -0,0 +1,89 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from scipy import linalg
+from tqdm import tqdm
+
+from basicsr.archs.inception import InceptionV3
+
+
+def load_patched_inception_v3(device='cuda', resize_input=True, normalize_input=False):
+ # we may not resize the input, but in [rosinality/stylegan2-pytorch] it
+ # does resize the input.
+ inception = InceptionV3([3], resize_input=resize_input, normalize_input=normalize_input)
+ inception = nn.DataParallel(inception).eval().to(device)
+ return inception
+
+
+@torch.no_grad()
+def extract_inception_features(data_generator, inception, len_generator=None, device='cuda'):
+ """Extract inception features.
+
+ Args:
+ data_generator (generator): A data generator.
+ inception (nn.Module): Inception model.
+ len_generator (int): Length of the data_generator to show the
+ progressbar. Default: None.
+ device (str): Device. Default: cuda.
+
+ Returns:
+ Tensor: Extracted features.
+ """
+ if len_generator is not None:
+ pbar = tqdm(total=len_generator, unit='batch', desc='Extract')
+ else:
+ pbar = None
+ features = []
+
+ for data in data_generator:
+ if pbar:
+ pbar.update(1)
+ data = data.to(device)
+ feature = inception(data)[0].view(data.shape[0], -1)
+ features.append(feature.to('cpu'))
+ if pbar:
+ pbar.close()
+ features = torch.cat(features, 0)
+ return features
+
+
+def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
+ """Numpy implementation of the Frechet Distance.
+
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) is:
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+ Stable version by Dougal J. Sutherland.
+
+ Args:
+ mu1 (np.array): The sample mean over activations.
+ sigma1 (np.array): The covariance matrix over activations for generated samples.
+        mu2 (np.array): The sample mean over activations, precalculated on a representative data set.
+        sigma2 (np.array): The covariance matrix over activations, precalculated on a representative data set.
+
+ Returns:
+ float: The Frechet Distance.
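+
+    :Example: (illustrative sketch with random features; in practice the statistics
+    come from Inception activations, e.g. via ``extract_inception_features``)
+
+    >>> feats1 = np.random.randn(1000, 64)
+    >>> feats2 = np.random.randn(1000, 64)
+    >>> mu1, sigma1 = feats1.mean(axis=0), np.cov(feats1, rowvar=False)
+    >>> mu2, sigma2 = feats2.mean(axis=0), np.cov(feats2, rowvar=False)
+    >>> fid = calculate_fid(mu1, sigma1, mu2, sigma2)  # a small non-negative float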
+ """
+ assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths'
+ assert sigma1.shape == sigma2.shape, ('Two covariances have different dimensions')
+
+ cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
+
+ # Product might be almost singular
+ if not np.isfinite(cov_sqrt).all():
+        print(f'Product of cov matrices is singular. Adding {eps} to diagonal of cov estimates')
+ offset = np.eye(sigma1.shape[0]) * eps
+ cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))
+
+ # Numerical error might give slight imaginary component
+ if np.iscomplexobj(cov_sqrt):
+ if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3):
+ m = np.max(np.abs(cov_sqrt.imag))
+ raise ValueError(f'Imaginary component {m}')
+ cov_sqrt = cov_sqrt.real
+
+ mean_diff = mu1 - mu2
+ mean_norm = mean_diff @ mean_diff
+ trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt)
+ fid = mean_norm + trace
+
+ return fid
diff --git a/StableSR/basicsr/metrics/metric_util.py b/StableSR/basicsr/metrics/metric_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a27c70a043beeeb59cfaf533079492293065448
--- /dev/null
+++ b/StableSR/basicsr/metrics/metric_util.py
@@ -0,0 +1,45 @@
+import numpy as np
+
+from basicsr.utils import bgr2ycbcr
+
+
+def reorder_image(img, input_order='HWC'):
+ """Reorder images to 'HWC' order.
+
+ If the input_order is (h, w), return (h, w, 1);
+ If the input_order is (c, h, w), return (h, w, c);
+ If the input_order is (h, w, c), return as it is.
+
+ Args:
+ img (ndarray): Input image.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
+            If the input image shape is (h, w), input_order has no effect.
+            Default: 'HWC'.
+
+ Returns:
+ ndarray: reordered image.
+ """
+
+ if input_order not in ['HWC', 'CHW']:
+ raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'")
+ if len(img.shape) == 2:
+ img = img[..., None]
+ if input_order == 'CHW':
+ img = img.transpose(1, 2, 0)
+ return img
+
+
+def to_y_channel(img):
+ """Change to Y channel of YCbCr.
+
+ Args:
+ img (ndarray): Images with range [0, 255].
+
+ Returns:
+        (ndarray): Images with range [0, 255] (float type) without rounding.
+ """
+ img = img.astype(np.float32) / 255.
+ if img.ndim == 3 and img.shape[2] == 3:
+ img = bgr2ycbcr(img, y_only=True)
+ img = img[..., None]
+ return img * 255.
diff --git a/StableSR/basicsr/metrics/niqe.py b/StableSR/basicsr/metrics/niqe.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3c1467f61d809ec3b2630073118460d9d61a861
--- /dev/null
+++ b/StableSR/basicsr/metrics/niqe.py
@@ -0,0 +1,199 @@
+import cv2
+import math
+import numpy as np
+import os
+from scipy.ndimage import convolve
+from scipy.special import gamma
+
+from basicsr.metrics.metric_util import reorder_image, to_y_channel
+from basicsr.utils.matlab_functions import imresize
+from basicsr.utils.registry import METRIC_REGISTRY
+
+
+def estimate_aggd_param(block):
+ """Estimate AGGD (Asymmetric Generalized Gaussian Distribution) parameters.
+
+ Args:
+ block (ndarray): 2D Image block.
+
+ Returns:
+ tuple: alpha (float), beta_l (float) and beta_r (float) for the AGGD
+        distribution (estimating the parameters in Equation 7 of the paper).
+ """
+ block = block.flatten()
+ gam = np.arange(0.2, 10.001, 0.001) # len = 9801
+ gam_reciprocal = np.reciprocal(gam)
+ r_gam = np.square(gamma(gam_reciprocal * 2)) / (gamma(gam_reciprocal) * gamma(gam_reciprocal * 3))
+
+ left_std = np.sqrt(np.mean(block[block < 0]**2))
+ right_std = np.sqrt(np.mean(block[block > 0]**2))
+ gammahat = left_std / right_std
+ rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2)
+ rhatnorm = (rhat * (gammahat**3 + 1) * (gammahat + 1)) / ((gammahat**2 + 1)**2)
+ array_position = np.argmin((r_gam - rhatnorm)**2)
+
+ alpha = gam[array_position]
+ beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
+ beta_r = right_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
+ return (alpha, beta_l, beta_r)
+
+
+def compute_feature(block):
+ """Compute features.
+
+ Args:
+ block (ndarray): 2D Image block.
+
+ Returns:
+ list: Features with length of 18.
+ """
+ feat = []
+ alpha, beta_l, beta_r = estimate_aggd_param(block)
+ feat.extend([alpha, (beta_l + beta_r) / 2])
+
+ # distortions disturb the fairly regular structure of natural images.
+ # This deviation can be captured by analyzing the sample distribution of
+ # the products of pairs of adjacent coefficients computed along
+ # horizontal, vertical and diagonal orientations.
+ shifts = [[0, 1], [1, 0], [1, 1], [1, -1]]
+ for i in range(len(shifts)):
+ shifted_block = np.roll(block, shifts[i], axis=(0, 1))
+ alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block)
+ # Eq. 8
+ mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha))
+ feat.extend([alpha, mean, beta_l, beta_r])
+ return feat
+
+
+def niqe(img, mu_pris_param, cov_pris_param, gaussian_window, block_size_h=96, block_size_w=96):
+ """Calculate NIQE (Natural Image Quality Evaluator) metric.
+
+ ``Paper: Making a "Completely Blind" Image Quality Analyzer``
+
+ This implementation could produce almost the same results as the official
+ MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
+
+ Note that we do not include block overlap height and width, since they are
+ always 0 in the official implementation.
+
+    For good performance, the official implementation recommends dividing the
+    distorted image into patches of the same size as those used to construct
+    the multivariate Gaussian model.
+
+ Args:
+ img (ndarray): Input image whose quality needs to be computed. The
+ image must be a gray or Y (of YCbCr) image with shape (h, w).
+ Range [0, 255] with float type.
+ mu_pris_param (ndarray): Mean of a pre-defined multivariate Gaussian
+ model calculated on the pristine dataset.
+ cov_pris_param (ndarray): Covariance of a pre-defined multivariate
+ Gaussian model calculated on the pristine dataset.
+ gaussian_window (ndarray): A 7x7 Gaussian window used for smoothing the
+ image.
+ block_size_h (int): Height of the blocks in to which image is divided.
+ Default: 96 (the official recommended value).
+ block_size_w (int): Width of the blocks in to which image is divided.
+ Default: 96 (the official recommended value).
+ """
+ assert img.ndim == 2, ('Input image must be a gray or Y (of YCbCr) image with shape (h, w).')
+ # crop image
+ h, w = img.shape
+ num_block_h = math.floor(h / block_size_h)
+ num_block_w = math.floor(w / block_size_w)
+ img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w]
+
+ distparam = [] # dist param is actually the multiscale features
+ for scale in (1, 2): # perform on two scales (1, 2)
+ mu = convolve(img, gaussian_window, mode='nearest')
+ sigma = np.sqrt(np.abs(convolve(np.square(img), gaussian_window, mode='nearest') - np.square(mu)))
+ # normalize, as in Eq. 1 in the paper
+        img_normalized = (img - mu) / (sigma + 1)
+
+        feat = []
+        for idx_w in range(num_block_w):
+            for idx_h in range(num_block_h):
+                # process each block
+                block = img_normalized[idx_h * block_size_h // scale:(idx_h + 1) * block_size_h // scale,
+                                       idx_w * block_size_w // scale:(idx_w + 1) * block_size_w // scale]
+ feat.append(compute_feature(block))
+
+ distparam.append(np.array(feat))
+
+ if scale == 1:
+ img = imresize(img / 255., scale=0.5, antialiasing=True)
+ img = img * 255.
+
+ distparam = np.concatenate(distparam, axis=1)
+
+ # fit a MVG (multivariate Gaussian) model to distorted patch features
+ mu_distparam = np.nanmean(distparam, axis=0)
+ # use nancov. ref: https://ww2.mathworks.cn/help/stats/nancov.html
+ distparam_no_nan = distparam[~np.isnan(distparam).any(axis=1)]
+ cov_distparam = np.cov(distparam_no_nan, rowvar=False)
+
+ # compute niqe quality, Eq. 10 in the paper
+ invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2)
+ quality = np.matmul(
+ np.matmul((mu_pris_param - mu_distparam), invcov_param), np.transpose((mu_pris_param - mu_distparam)))
+
+ quality = np.sqrt(quality)
+ quality = float(np.squeeze(quality))
+ return quality
+
+
+@METRIC_REGISTRY.register()
+def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y', **kwargs):
+ """Calculate NIQE (Natural Image Quality Evaluator) metric.
+
+ ``Paper: Making a "Completely Blind" Image Quality Analyzer``
+
+ This implementation could produce almost the same results as the official
+ MATLAB codes: http://live.ece.utexas.edu/research/quality/niqe_release.zip
+
+ > MATLAB R2021a result for tests/data/baboon.png: 5.72957338 (5.7296)
+ > Our re-implementation result for tests/data/baboon.png: 5.7295763 (5.7296)
+
+ We use the official params estimated from the pristine dataset.
+ We use the recommended block size (96, 96) without overlaps.
+
+ Args:
+ img (ndarray): Input image whose quality needs to be computed.
+ The input image must be in range [0, 255] with float/int type.
+ The input_order of image can be 'HW' or 'HWC' or 'CHW'. (BGR order)
+ If the input order is 'HWC' or 'CHW', it will be converted to gray
+ or Y (of YCbCr) image according to the ``convert_to`` argument.
+ crop_border (int): Cropped pixels in each edge of an image. These
+ pixels are not involved in the metric calculation.
+ input_order (str): Whether the input order is 'HW', 'HWC' or 'CHW'.
+ Default: 'HWC'.
+ convert_to (str): Whether converted to 'y' (of MATLAB YCbCr) or 'gray'.
+ Default: 'y'.
+
+ Returns:
+ float: NIQE result.
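+
+    :Example: (illustrative; the path and the expected value follow the note above)
+
+    >>> import cv2
+    >>> img = cv2.imread('tests/data/baboon.png', cv2.IMREAD_UNCHANGED)  # doctest: +SKIP
+    >>> round(calculate_niqe(img, crop_border=0), 4)  # doctest: +SKIP
+    5.7296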
+ """
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+ # we use the official params estimated from the pristine dataset.
+ niqe_pris_params = np.load(os.path.join(ROOT_DIR, 'niqe_pris_params.npz'))
+ mu_pris_param = niqe_pris_params['mu_pris_param']
+ cov_pris_param = niqe_pris_params['cov_pris_param']
+ gaussian_window = niqe_pris_params['gaussian_window']
+
+ img = img.astype(np.float32)
+ if input_order != 'HW':
+ img = reorder_image(img, input_order=input_order)
+ if convert_to == 'y':
+ img = to_y_channel(img)
+ elif convert_to == 'gray':
+ img = cv2.cvtColor(img / 255., cv2.COLOR_BGR2GRAY) * 255.
+ img = np.squeeze(img)
+
+ if crop_border != 0:
+ img = img[crop_border:-crop_border, crop_border:-crop_border]
+
+ # round is necessary for being consistent with MATLAB's result
+ img = img.round()
+
+ niqe_result = niqe(img, mu_pris_param, cov_pris_param, gaussian_window)
+
+ return niqe_result
diff --git a/StableSR/basicsr/metrics/niqe_pris_params.npz b/StableSR/basicsr/metrics/niqe_pris_params.npz
new file mode 100644
index 0000000000000000000000000000000000000000..42f06a9a18e6ed8bbf7933bec1477b189ef798de
--- /dev/null
+++ b/StableSR/basicsr/metrics/niqe_pris_params.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2a7c182a68c9e7f1b2e2e5ec723279d6f65d912b6fcaf37eb2bf03d7367c4296
+size 11850
diff --git a/StableSR/basicsr/metrics/psnr_ssim.py b/StableSR/basicsr/metrics/psnr_ssim.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab03113f89805c990ff22795601274bf45db23a1
--- /dev/null
+++ b/StableSR/basicsr/metrics/psnr_ssim.py
@@ -0,0 +1,231 @@
+import cv2
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from basicsr.metrics.metric_util import reorder_image, to_y_channel
+from basicsr.utils.color_util import rgb2ycbcr_pt
+from basicsr.utils.registry import METRIC_REGISTRY
+
+
+@METRIC_REGISTRY.register()
+def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
+ """Calculate PSNR (Peak Signal-to-Noise Ratio).
+
+ Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
+
+ Args:
+ img (ndarray): Images with range [0, 255].
+ img2 (ndarray): Images with range [0, 255].
+ crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'.
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+ Returns:
+ float: PSNR result.
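+
+    :Example: (illustrative synthetic images; a constant difference of 1 gives 10 * log10(255 ** 2))
+
+    >>> import numpy as np
+    >>> img = np.zeros((32, 32, 3))
+    >>> img2 = np.ones((32, 32, 3))
+    >>> round(float(calculate_psnr(img, img2, crop_border=0)), 2)
+    48.13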
+ """
+
+ assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
+ if input_order not in ['HWC', 'CHW']:
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
+ img = reorder_image(img, input_order=input_order)
+ img2 = reorder_image(img2, input_order=input_order)
+
+ if crop_border != 0:
+ img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
+ img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
+
+ if test_y_channel:
+ img = to_y_channel(img)
+ img2 = to_y_channel(img2)
+
+ img = img.astype(np.float64)
+ img2 = img2.astype(np.float64)
+
+ mse = np.mean((img - img2)**2)
+ if mse == 0:
+ return float('inf')
+ return 10. * np.log10(255. * 255. / mse)
+
+
+@METRIC_REGISTRY.register()
+def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False, **kwargs):
+ """Calculate PSNR (Peak Signal-to-Noise Ratio) (PyTorch version).
+
+ Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
+
+ Args:
+ img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
+ img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
+ crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+ Returns:
+ float: PSNR result.
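+
+    :Example: (illustrative; identical batched inputs return one PSNR value per sample)
+
+    >>> img = torch.rand(2, 3, 32, 32)
+    >>> calculate_psnr_pt(img, img.clone(), crop_border=0).shape
+    torch.Size([2])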
+ """
+
+ assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
+
+ if crop_border != 0:
+ img = img[:, :, crop_border:-crop_border, crop_border:-crop_border]
+ img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border]
+
+ if test_y_channel:
+ img = rgb2ycbcr_pt(img, y_only=True)
+ img2 = rgb2ycbcr_pt(img2, y_only=True)
+
+ img = img.to(torch.float64)
+ img2 = img2.to(torch.float64)
+
+ mse = torch.mean((img - img2)**2, dim=[1, 2, 3])
+ return 10. * torch.log10(1. / (mse + 1e-8))
+
+
+@METRIC_REGISTRY.register()
+def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
+ """Calculate SSIM (structural similarity).
+
+ ``Paper: Image quality assessment: From error visibility to structural similarity``
+
+ The results are the same as that of the official released MATLAB code in
+ https://ece.uwaterloo.ca/~z70wang/research/ssim/.
+
+ For three-channel images, SSIM is calculated for each channel and then
+ averaged.
+
+ Args:
+ img (ndarray): Images with range [0, 255].
+ img2 (ndarray): Images with range [0, 255].
+ crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
+ Default: 'HWC'.
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+ Returns:
+ float: SSIM result.
+ """
+
+ assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
+ if input_order not in ['HWC', 'CHW']:
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
+ img = reorder_image(img, input_order=input_order)
+ img2 = reorder_image(img2, input_order=input_order)
+
+ if crop_border != 0:
+ img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
+ img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
+
+ if test_y_channel:
+ img = to_y_channel(img)
+ img2 = to_y_channel(img2)
+
+ img = img.astype(np.float64)
+ img2 = img2.astype(np.float64)
+
+ ssims = []
+ for i in range(img.shape[2]):
+ ssims.append(_ssim(img[..., i], img2[..., i]))
+ return np.array(ssims).mean()
+
+
+@METRIC_REGISTRY.register()
+def calculate_ssim_pt(img, img2, crop_border, test_y_channel=False, **kwargs):
+ """Calculate SSIM (structural similarity) (PyTorch version).
+
+ ``Paper: Image quality assessment: From error visibility to structural similarity``
+
+ The results are the same as that of the official released MATLAB code in
+ https://ece.uwaterloo.ca/~z70wang/research/ssim/.
+
+ For three-channel images, SSIM is calculated for each channel and then
+ averaged.
+
+ Args:
+ img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
+ img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
+ crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+ Returns:
+ float: SSIM result.
+ """
+
+ assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
+
+ if crop_border != 0:
+ img = img[:, :, crop_border:-crop_border, crop_border:-crop_border]
+ img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border]
+
+ if test_y_channel:
+ img = rgb2ycbcr_pt(img, y_only=True)
+ img2 = rgb2ycbcr_pt(img2, y_only=True)
+
+ img = img.to(torch.float64)
+ img2 = img2.to(torch.float64)
+
+ ssim = _ssim_pth(img * 255., img2 * 255.)
+ return ssim
+
+
+def _ssim(img, img2):
+ """Calculate SSIM (structural similarity) for one channel images.
+
+ It is called by func:`calculate_ssim`.
+
+ Args:
+ img (ndarray): Images with range [0, 255] with order 'HWC'.
+ img2 (ndarray): Images with range [0, 255] with order 'HWC'.
+
+ Returns:
+ float: SSIM result.
+ """
+
+ c1 = (0.01 * 255)**2
+ c2 = (0.03 * 255)**2
+ kernel = cv2.getGaussianKernel(11, 1.5)
+ window = np.outer(kernel, kernel.transpose())
+
+ mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5] # valid mode for window size 11
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+ mu1_sq = mu1**2
+ mu2_sq = mu2**2
+ mu1_mu2 = mu1 * mu2
+ sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+ sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+ ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2))
+ return ssim_map.mean()
+
+
+def _ssim_pth(img, img2):
+ """Calculate SSIM (structural similarity) (PyTorch version).
+
+ It is called by func:`calculate_ssim_pt`.
+
+ Args:
+ img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
+ img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
+
+ Returns:
+ float: SSIM result.
+ """
+ c1 = (0.01 * 255)**2
+ c2 = (0.03 * 255)**2
+
+ kernel = cv2.getGaussianKernel(11, 1.5)
+ window = np.outer(kernel, kernel.transpose())
+ window = torch.from_numpy(window).view(1, 1, 11, 11).expand(img.size(1), 1, 11, 11).to(img.dtype).to(img.device)
+
+ mu1 = F.conv2d(img, window, stride=1, padding=0, groups=img.shape[1]) # valid mode
+ mu2 = F.conv2d(img2, window, stride=1, padding=0, groups=img2.shape[1]) # valid mode
+ mu1_sq = mu1.pow(2)
+ mu2_sq = mu2.pow(2)
+ mu1_mu2 = mu1 * mu2
+ sigma1_sq = F.conv2d(img * img, window, stride=1, padding=0, groups=img.shape[1]) - mu1_sq
+ sigma2_sq = F.conv2d(img2 * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu2_sq
+ sigma12 = F.conv2d(img * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu1_mu2
+
+ cs_map = (2 * sigma12 + c2) / (sigma1_sq + sigma2_sq + c2)
+ ssim_map = ((2 * mu1_mu2 + c1) / (mu1_sq + mu2_sq + c1)) * cs_map
+ return ssim_map.mean([1, 2, 3])
diff --git a/StableSR/basicsr/metrics/test_metrics/test_psnr_ssim.py b/StableSR/basicsr/metrics/test_metrics/test_psnr_ssim.py
new file mode 100644
index 0000000000000000000000000000000000000000..18b05a73a0e38e89b2321ddc9415123a92f5c5a4
--- /dev/null
+++ b/StableSR/basicsr/metrics/test_metrics/test_psnr_ssim.py
@@ -0,0 +1,52 @@
+import cv2
+import torch
+
+from basicsr.metrics import calculate_psnr, calculate_ssim
+from basicsr.metrics.psnr_ssim import calculate_psnr_pt, calculate_ssim_pt
+from basicsr.utils import img2tensor
+
+
+def test(img_path, img_path2, crop_border, test_y_channel=False):
+ img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
+ img2 = cv2.imread(img_path2, cv2.IMREAD_UNCHANGED)
+
+ # --------------------- Numpy ---------------------
+ psnr = calculate_psnr(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel)
+ ssim = calculate_ssim(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel)
+ print(f'\tNumpy\tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}')
+
+ # --------------------- PyTorch (CPU) ---------------------
+ img = img2tensor(img / 255., bgr2rgb=True, float32=True).unsqueeze_(0)
+ img2 = img2tensor(img2 / 255., bgr2rgb=True, float32=True).unsqueeze_(0)
+
+ psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
+ ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
+ print(f'\tTensor (CPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}')
+
+ # --------------------- PyTorch (GPU) ---------------------
+ img = img.cuda()
+ img2 = img2.cuda()
+ psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
+ ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
+ print(f'\tTensor (GPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}')
+
+ psnr_pth = calculate_psnr_pt(
+ torch.repeat_interleave(img, 2, dim=0),
+ torch.repeat_interleave(img2, 2, dim=0),
+ crop_border=crop_border,
+ test_y_channel=test_y_channel)
+ ssim_pth = calculate_ssim_pt(
+ torch.repeat_interleave(img, 2, dim=0),
+ torch.repeat_interleave(img2, 2, dim=0),
+ crop_border=crop_border,
+ test_y_channel=test_y_channel)
+ print(f'\tTensor (GPU batch) \tPSNR: {psnr_pth[0]:.6f}, {psnr_pth[1]:.6f} dB,'
+ f'\tSSIM: {ssim_pth[0]:.6f}, {ssim_pth[1]:.6f}')
+
+
+if __name__ == '__main__':
+ test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=False)
+ test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=True)
+
+ test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=False)
+ test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=True)
diff --git a/StableSR/basicsr/models/__init__.py b/StableSR/basicsr/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..85796deae014c20a9aa600133468d04900c4fb89
--- /dev/null
+++ b/StableSR/basicsr/models/__init__.py
@@ -0,0 +1,29 @@
+import importlib
+from copy import deepcopy
+from os import path as osp
+
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.registry import MODEL_REGISTRY
+
+__all__ = ['build_model']
+
+# automatically scan and import model modules for registry
+# scan all the files under the 'models' folder and collect files ending with '_model.py'
+model_folder = osp.dirname(osp.abspath(__file__))
+model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
+# import all the model modules
+_model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames]
+
+
+def build_model(opt):
+ """Build model from options.
+
+ Args:
+ opt (dict): Configuration. It must contain:
+ model_type (str): Model type.
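+
+    :Example: (a minimal illustrative sketch; real option dicts also carry network,
+    dataset, path and training/validation settings)
+
+    >>> opt = dict(model_type='SRModel', is_train=False, num_gpu=0, dist=False)
+    >>> model = build_model(opt)  # doctest: +SKIP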
+ """
+ opt = deepcopy(opt)
+ model = MODEL_REGISTRY.get(opt['model_type'])(opt)
+ logger = get_root_logger()
+ logger.info(f'Model [{model.__class__.__name__}] is created.')
+ return model
diff --git a/StableSR/basicsr/models/base_model.py b/StableSR/basicsr/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbf8229f59dee86a7f9f95c1d07da785fb5f15b3
--- /dev/null
+++ b/StableSR/basicsr/models/base_model.py
@@ -0,0 +1,392 @@
+import os
+import time
+import torch
+from collections import OrderedDict
+from copy import deepcopy
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+
+from basicsr.models import lr_scheduler as lr_scheduler
+from basicsr.utils import get_root_logger
+from basicsr.utils.dist_util import master_only
+
+
+class BaseModel():
+ """Base model."""
+
+ def __init__(self, opt):
+ self.opt = opt
+ self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
+ self.is_train = opt['is_train']
+ self.schedulers = []
+ self.optimizers = []
+
+ def feed_data(self, data):
+ pass
+
+ def optimize_parameters(self):
+ pass
+
+ def get_current_visuals(self):
+ pass
+
+ def save(self, epoch, current_iter):
+ """Save networks and training state."""
+ pass
+
+ def validation(self, dataloader, current_iter, tb_logger, save_img=False):
+ """Validation function.
+
+ Args:
+ dataloader (torch.utils.data.DataLoader): Validation dataloader.
+ current_iter (int): Current iteration.
+ tb_logger (tensorboard logger): Tensorboard logger.
+ save_img (bool): Whether to save images. Default: False.
+ """
+ if self.opt['dist']:
+ self.dist_validation(dataloader, current_iter, tb_logger, save_img)
+ else:
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
+
+ def _initialize_best_metric_results(self, dataset_name):
+ """Initialize the best metric results dict for recording the best metric value and iteration."""
+ if hasattr(self, 'best_metric_results') and dataset_name in self.best_metric_results:
+ return
+ elif not hasattr(self, 'best_metric_results'):
+ self.best_metric_results = dict()
+
+ # add a dataset record
+ record = dict()
+ for metric, content in self.opt['val']['metrics'].items():
+ better = content.get('better', 'higher')
+ init_val = float('-inf') if better == 'higher' else float('inf')
+ record[metric] = dict(better=better, val=init_val, iter=-1)
+ self.best_metric_results[dataset_name] = record
+
+ def _update_best_metric_result(self, dataset_name, metric, val, current_iter):
+ if self.best_metric_results[dataset_name][metric]['better'] == 'higher':
+ if val >= self.best_metric_results[dataset_name][metric]['val']:
+ self.best_metric_results[dataset_name][metric]['val'] = val
+ self.best_metric_results[dataset_name][metric]['iter'] = current_iter
+ else:
+ if val <= self.best_metric_results[dataset_name][metric]['val']:
+ self.best_metric_results[dataset_name][metric]['val'] = val
+ self.best_metric_results[dataset_name][metric]['iter'] = current_iter
+
+ def model_ema(self, decay=0.999):
+ net_g = self.get_bare_model(self.net_g)
+
+ net_g_params = dict(net_g.named_parameters())
+ net_g_ema_params = dict(self.net_g_ema.named_parameters())
+
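+        # exponential moving average update: p_ema = decay * p_ema + (1 - decay) * p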
+ for k in net_g_ema_params.keys():
+ net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay)
+
+ def get_current_log(self):
+ return self.log_dict
+
+ def model_to_device(self, net):
+        """Model to device. It also wraps models with DistributedDataParallel
+ or DataParallel.
+
+ Args:
+ net (nn.Module)
+ """
+ net = net.to(self.device)
+ if self.opt['dist']:
+ find_unused_parameters = self.opt.get('find_unused_parameters', False)
+ net = DistributedDataParallel(
+ net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
+ elif self.opt['num_gpu'] > 1:
+ net = DataParallel(net)
+ return net
+
+ def get_optimizer(self, optim_type, params, lr, **kwargs):
+ if optim_type == 'Adam':
+ optimizer = torch.optim.Adam(params, lr, **kwargs)
+ elif optim_type == 'AdamW':
+ optimizer = torch.optim.AdamW(params, lr, **kwargs)
+ elif optim_type == 'Adamax':
+ optimizer = torch.optim.Adamax(params, lr, **kwargs)
+ elif optim_type == 'SGD':
+ optimizer = torch.optim.SGD(params, lr, **kwargs)
+ elif optim_type == 'ASGD':
+ optimizer = torch.optim.ASGD(params, lr, **kwargs)
+ elif optim_type == 'RMSprop':
+ optimizer = torch.optim.RMSprop(params, lr, **kwargs)
+ elif optim_type == 'Rprop':
+ optimizer = torch.optim.Rprop(params, lr, **kwargs)
+ else:
+ raise NotImplementedError(f'optimizer {optim_type} is not supported yet.')
+ return optimizer
+
+ def setup_schedulers(self):
+ """Set up schedulers."""
+ train_opt = self.opt['train']
+ scheduler_type = train_opt['scheduler'].pop('type')
+ if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
+ for optimizer in self.optimizers:
+ self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler']))
+ elif scheduler_type == 'CosineAnnealingRestartLR':
+ for optimizer in self.optimizers:
+ self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler']))
+ else:
+ raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.')
+
+ def get_bare_model(self, net):
+ """Get bare model, especially under wrapping with
+ DistributedDataParallel or DataParallel.
+ """
+ if isinstance(net, (DataParallel, DistributedDataParallel)):
+ net = net.module
+ return net
+
+ @master_only
+ def print_network(self, net):
+ """Print the str and parameter number of a network.
+
+ Args:
+ net (nn.Module)
+ """
+ if isinstance(net, (DataParallel, DistributedDataParallel)):
+ net_cls_str = f'{net.__class__.__name__} - {net.module.__class__.__name__}'
+ else:
+ net_cls_str = f'{net.__class__.__name__}'
+
+ net = self.get_bare_model(net)
+ net_str = str(net)
+ net_params = sum(map(lambda x: x.numel(), net.parameters()))
+
+ logger = get_root_logger()
+ logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}')
+ logger.info(net_str)
+
+ def _set_lr(self, lr_groups_l):
+ """Set learning rate for warm-up.
+
+ Args:
+ lr_groups_l (list): List for lr_groups, each for an optimizer.
+ """
+ for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
+ for param_group, lr in zip(optimizer.param_groups, lr_groups):
+ param_group['lr'] = lr
+
+ def _get_init_lr(self):
+ """Get the initial lr, which is set by the scheduler.
+ """
+ init_lr_groups_l = []
+ for optimizer in self.optimizers:
+ init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
+ return init_lr_groups_l
+
+ def update_learning_rate(self, current_iter, warmup_iter=-1):
+ """Update learning rate.
+
+ Args:
+ current_iter (int): Current iteration.
+ warmup_iter (int): Warm-up iter numbers. -1 for no warm-up.
+ Default: -1.
+ """
+ if current_iter > 1:
+ for scheduler in self.schedulers:
+ scheduler.step()
+ # set up warm-up learning rate
+ if current_iter < warmup_iter:
+ # get initial lr for each group
+ init_lr_g_l = self._get_init_lr()
+ # modify warming-up learning rates
+ # currently only support linearly warm up
+ warm_up_lr_l = []
+ for init_lr_g in init_lr_g_l:
+ warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g])
+ # set learning rate
+ self._set_lr(warm_up_lr_l)
+
+ def get_current_learning_rate(self):
+ return [param_group['lr'] for param_group in self.optimizers[0].param_groups]
+
+ @master_only
+ def save_network(self, net, net_label, current_iter, param_key='params'):
+ """Save networks.
+
+ Args:
+ net (nn.Module | list[nn.Module]): Network(s) to be saved.
+ net_label (str): Network label.
+ current_iter (int): Current iter number.
+ param_key (str | list[str]): The parameter key(s) to save network.
+ Default: 'params'.
+ """
+ if current_iter == -1:
+ current_iter = 'latest'
+ save_filename = f'{net_label}_{current_iter}.pth'
+ save_path = os.path.join(self.opt['path']['models'], save_filename)
+
+ net = net if isinstance(net, list) else [net]
+ param_key = param_key if isinstance(param_key, list) else [param_key]
+ assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.'
+
+ save_dict = {}
+ for net_, param_key_ in zip(net, param_key):
+ net_ = self.get_bare_model(net_)
+ state_dict = net_.state_dict()
+ for key, param in state_dict.items():
+ if key.startswith('module.'): # remove unnecessary 'module.'
+ key = key[7:]
+ state_dict[key] = param.cpu()
+ save_dict[param_key_] = state_dict
+
+ # avoid occasional writing errors
+ retry = 3
+ while retry > 0:
+ try:
+ torch.save(save_dict, save_path)
+ except Exception as e:
+ logger = get_root_logger()
+ logger.warning(f'Save model error: {e}, remaining retry times: {retry - 1}')
+ time.sleep(1)
+ else:
+ break
+ finally:
+ retry -= 1
+ if retry == 0:
+ logger.warning(f'Still cannot save {save_path}. Just ignore it.')
+ # raise IOError(f'Cannot save {save_path}.')
+
+ def _print_different_keys_loading(self, crt_net, load_net, strict=True):
+ """Print keys with different name or different size when loading models.
+
+ 1. Print keys with different names.
+ 2. If strict=False, print the same key but with different tensor size.
+            It also ignores keys with different sizes (they will not be loaded).
+
+ Args:
+ crt_net (torch model): Current network.
+ load_net (dict): Loaded network.
+ strict (bool): Whether strictly loaded. Default: True.
+ """
+ crt_net = self.get_bare_model(crt_net)
+ crt_net = crt_net.state_dict()
+ crt_net_keys = set(crt_net.keys())
+ load_net_keys = set(load_net.keys())
+
+ logger = get_root_logger()
+ if crt_net_keys != load_net_keys:
+ logger.warning('Current net - loaded net:')
+ for v in sorted(list(crt_net_keys - load_net_keys)):
+ logger.warning(f' {v}')
+ logger.warning('Loaded net - current net:')
+ for v in sorted(list(load_net_keys - crt_net_keys)):
+ logger.warning(f' {v}')
+
+ # check the size for the same keys
+ if not strict:
+ common_keys = crt_net_keys & load_net_keys
+ for k in common_keys:
+ if crt_net[k].size() != load_net[k].size():
+ logger.warning(f'Size different, ignore [{k}]: crt_net: '
+ f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
+ load_net[k + '.ignore'] = load_net.pop(k)
+
+ def load_network(self, net, load_path, strict=True, param_key='params'):
+ """Load network.
+
+ Args:
+ load_path (str): The path of networks to be loaded.
+ net (nn.Module): Network.
+ strict (bool): Whether strictly loaded.
+ param_key (str): The parameter key of loaded network. If set to
+ None, use the root 'path'.
+ Default: 'params'.
+ """
+ logger = get_root_logger()
+ net = self.get_bare_model(net)
+ load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
+ if param_key is not None:
+ if param_key not in load_net and 'params' in load_net:
+                logger.info(f'Loading: {param_key} does not exist, use params.')
+                param_key = 'params'
+ load_net = load_net[param_key]
+ logger.info(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].')
+ # remove unnecessary 'module.'
+ for k, v in deepcopy(load_net).items():
+ if k.startswith('module.'):
+ load_net[k[7:]] = v
+ load_net.pop(k)
+ self._print_different_keys_loading(net, load_net, strict)
+ net.load_state_dict(load_net, strict=strict)
+
+ @master_only
+ def save_training_state(self, epoch, current_iter):
+ """Save training states during training, which will be used for
+ resuming.
+
+ Args:
+ epoch (int): Current epoch.
+ current_iter (int): Current iteration.
+ """
+ if current_iter != -1:
+ state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []}
+ for o in self.optimizers:
+ state['optimizers'].append(o.state_dict())
+ for s in self.schedulers:
+ state['schedulers'].append(s.state_dict())
+ save_filename = f'{current_iter}.state'
+ save_path = os.path.join(self.opt['path']['training_states'], save_filename)
+
+ # avoid occasional writing errors
+ retry = 3
+ while retry > 0:
+ try:
+ torch.save(state, save_path)
+ except Exception as e:
+ logger = get_root_logger()
+ logger.warning(f'Save training state error: {e}, remaining retry times: {retry - 1}')
+ time.sleep(1)
+ else:
+ break
+ finally:
+ retry -= 1
+ if retry == 0:
+ logger.warning(f'Still cannot save {save_path}. Just ignore it.')
+ # raise IOError(f'Cannot save {save_path}.')
+
+ def resume_training(self, resume_state):
+ """Reload the optimizers and schedulers for resumed training.
+
+ Args:
+ resume_state (dict): Resume state.
+ """
+ resume_optimizers = resume_state['optimizers']
+ resume_schedulers = resume_state['schedulers']
+ assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
+ assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
+ for i, o in enumerate(resume_optimizers):
+ self.optimizers[i].load_state_dict(o)
+ for i, s in enumerate(resume_schedulers):
+ self.schedulers[i].load_state_dict(s)
+
+ def reduce_loss_dict(self, loss_dict):
+ """reduce loss dict.
+
+        In distributed training, it averages the losses among different GPUs.
+
+ Args:
+ loss_dict (OrderedDict): Loss dict.
+ """
+ with torch.no_grad():
+ if self.opt['dist']:
+ keys = []
+ losses = []
+ for name, value in loss_dict.items():
+ keys.append(name)
+ losses.append(value)
+ losses = torch.stack(losses, 0)
+ torch.distributed.reduce(losses, dst=0)
+ if self.opt['rank'] == 0:
+ losses /= self.opt['world_size']
+ loss_dict = {key: loss for key, loss in zip(keys, losses)}
+
+ log_dict = OrderedDict()
+ for name, value in loss_dict.items():
+ log_dict[name] = value.mean().item()
+
+ return log_dict
diff --git a/StableSR/basicsr/models/edvr_model.py b/StableSR/basicsr/models/edvr_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bdbf7b94fe3f06c76fbf2a4941621f64e0003e7
--- /dev/null
+++ b/StableSR/basicsr/models/edvr_model.py
@@ -0,0 +1,62 @@
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import MODEL_REGISTRY
+from .video_base_model import VideoBaseModel
+
+
+@MODEL_REGISTRY.register()
+class EDVRModel(VideoBaseModel):
+ """EDVR Model.
+
+ Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks. # noqa: E501
+ """
+
+ def __init__(self, opt):
+ super(EDVRModel, self).__init__(opt)
+ if self.is_train:
+ self.train_tsa_iter = opt['train'].get('tsa_iter')
+
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ dcn_lr_mul = train_opt.get('dcn_lr_mul', 1)
+ logger = get_root_logger()
+        logger.info(f'Multiply the learning rate for dcn by {dcn_lr_mul}.')
+ if dcn_lr_mul == 1:
+ optim_params = self.net_g.parameters()
+ else: # separate dcn params and normal params for different lr
+ normal_params = []
+ dcn_params = []
+ for name, param in self.net_g.named_parameters():
+ if 'dcn' in name:
+ dcn_params.append(param)
+ else:
+ normal_params.append(param)
+ optim_params = [
+ { # add normal params first
+ 'params': normal_params,
+ 'lr': train_opt['optim_g']['lr']
+ },
+ {
+ 'params': dcn_params,
+ 'lr': train_opt['optim_g']['lr'] * dcn_lr_mul
+ },
+ ]
+
+ optim_type = train_opt['optim_g'].pop('type')
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
+ self.optimizers.append(self.optimizer_g)
+
+ def optimize_parameters(self, current_iter):
+ if self.train_tsa_iter:
+ if current_iter == 1:
+ logger = get_root_logger()
+ logger.info(f'Only train TSA module for {self.train_tsa_iter} iters.')
+ for name, param in self.net_g.named_parameters():
+ if 'fusion' not in name:
+ param.requires_grad = False
+ elif current_iter == self.train_tsa_iter:
+ logger = get_root_logger()
+ logger.warning('Train all the parameters.')
+ for param in self.net_g.parameters():
+ param.requires_grad = True
+
+ super(EDVRModel, self).optimize_parameters(current_iter)
diff --git a/StableSR/basicsr/models/esrgan_model.py b/StableSR/basicsr/models/esrgan_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d746d0e29418d9e8f35fa9c1e3a315d694075be
--- /dev/null
+++ b/StableSR/basicsr/models/esrgan_model.py
@@ -0,0 +1,83 @@
+import torch
+from collections import OrderedDict
+
+from basicsr.utils.registry import MODEL_REGISTRY
+from .srgan_model import SRGANModel
+
+
+@MODEL_REGISTRY.register()
+class ESRGANModel(SRGANModel):
+ """ESRGAN model for single image super-resolution."""
+
+ def optimize_parameters(self, current_iter):
+ # optimize net_g
+ for p in self.net_d.parameters():
+ p.requires_grad = False
+
+ self.optimizer_g.zero_grad()
+ self.output = self.net_g(self.lq)
+
+ l_g_total = 0
+ loss_dict = OrderedDict()
+ if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
+ # pixel loss
+ if self.cri_pix:
+ l_g_pix = self.cri_pix(self.output, self.gt)
+ l_g_total += l_g_pix
+ loss_dict['l_g_pix'] = l_g_pix
+ # perceptual loss
+ if self.cri_perceptual:
+ l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
+ if l_g_percep is not None:
+ l_g_total += l_g_percep
+ loss_dict['l_g_percep'] = l_g_percep
+ if l_g_style is not None:
+ l_g_total += l_g_style
+ loss_dict['l_g_style'] = l_g_style
+ # gan loss (relativistic gan)
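+            # relativistic GAN: each prediction is judged relative to the mean
+            # prediction of the opposite class, following the ESRGAN formulation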
+ real_d_pred = self.net_d(self.gt).detach()
+ fake_g_pred = self.net_d(self.output)
+ l_g_real = self.cri_gan(real_d_pred - torch.mean(fake_g_pred), False, is_disc=False)
+ l_g_fake = self.cri_gan(fake_g_pred - torch.mean(real_d_pred), True, is_disc=False)
+ l_g_gan = (l_g_real + l_g_fake) / 2
+
+ l_g_total += l_g_gan
+ loss_dict['l_g_gan'] = l_g_gan
+
+ l_g_total.backward()
+ self.optimizer_g.step()
+
+ # optimize net_d
+ for p in self.net_d.parameters():
+ p.requires_grad = True
+
+ self.optimizer_d.zero_grad()
+ # gan loss (relativistic gan)
+
+ # In order to avoid the error in distributed training:
+ # "Error detected in CudnnBatchNormBackward: RuntimeError: one of
+ # the variables needed for gradient computation has been modified by
+ # an inplace operation",
+ # we separate the backwards for real and fake, and also detach the
+ # tensor for calculating mean.
+
+ # real
+ fake_d_pred = self.net_d(self.output).detach()
+ real_d_pred = self.net_d(self.gt)
+ l_d_real = self.cri_gan(real_d_pred - torch.mean(fake_d_pred), True, is_disc=True) * 0.5
+ l_d_real.backward()
+ # fake
+ fake_d_pred = self.net_d(self.output.detach())
+ l_d_fake = self.cri_gan(fake_d_pred - torch.mean(real_d_pred.detach()), False, is_disc=True) * 0.5
+ l_d_fake.backward()
+ self.optimizer_d.step()
+
+ loss_dict['l_d_real'] = l_d_real
+ loss_dict['l_d_fake'] = l_d_fake
+ loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
+ loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
+
+ if self.ema_decay > 0:
+ self.model_ema(decay=self.ema_decay)
diff --git a/StableSR/basicsr/models/hifacegan_model.py b/StableSR/basicsr/models/hifacegan_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..435a2b179d6b7c670fe96a83ce45b461300b2c89
--- /dev/null
+++ b/StableSR/basicsr/models/hifacegan_model.py
@@ -0,0 +1,288 @@
+import torch
+from collections import OrderedDict
+from os import path as osp
+from tqdm import tqdm
+
+from basicsr.archs import build_network
+from basicsr.losses import build_loss
+from basicsr.metrics import calculate_metric
+from basicsr.utils import imwrite, tensor2img
+from basicsr.utils.registry import MODEL_REGISTRY
+from .sr_model import SRModel
+
+
+@MODEL_REGISTRY.register()
+class HiFaceGANModel(SRModel):
+ """HiFaceGAN model for generic-purpose face restoration.
+ No prior modeling required, works for any degradations.
+ Currently doesn't support EMA for inference.
+ """
+
+ def init_training_settings(self):
+
+ train_opt = self.opt['train']
+ self.ema_decay = train_opt.get('ema_decay', 0)
+ if self.ema_decay > 0:
+            raise NotImplementedError('HiFaceGAN does not support EMA now. Pass')
+
+ self.net_g.train()
+
+ self.net_d = build_network(self.opt['network_d'])
+ self.net_d = self.model_to_device(self.net_d)
+ self.print_network(self.net_d)
+
+ # define losses
+ # HiFaceGAN does not use pixel loss by default
+ if train_opt.get('pixel_opt'):
+ self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
+ else:
+ self.cri_pix = None
+
+ if train_opt.get('perceptual_opt'):
+ self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
+ else:
+ self.cri_perceptual = None
+
+ if train_opt.get('feature_matching_opt'):
+ self.cri_feat = build_loss(train_opt['feature_matching_opt']).to(self.device)
+ else:
+ self.cri_feat = None
+
+ if self.cri_pix is None and self.cri_perceptual is None:
+ raise ValueError('Both pixel and perceptual losses are None.')
+
+ if train_opt.get('gan_opt'):
+ self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
+
+ self.net_d_iters = train_opt.get('net_d_iters', 1)
+ self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
+ # set up optimizers and schedulers
+ self.setup_optimizers()
+ self.setup_schedulers()
+
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ # optimizer g
+ optim_type = train_opt['optim_g'].pop('type')
+ self.optimizer_g = self.get_optimizer(optim_type, self.net_g.parameters(), **train_opt['optim_g'])
+ self.optimizers.append(self.optimizer_g)
+ # optimizer d
+ optim_type = train_opt['optim_d'].pop('type')
+ self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
+ self.optimizers.append(self.optimizer_d)
+
+ def discriminate(self, input_lq, output, ground_truth):
+ """
+ This is a conditional (on the input) discriminator
+ In Batch Normalization, the fake and real images are
+ recommended to be in the same batch to avoid disparate
+ statistics in fake and real images.
+ So both fake and real images are fed to D all at once.
+ """
+ h, w = output.shape[-2:]
+ if output.shape[-2:] != input_lq.shape[-2:]:
+ lq = torch.nn.functional.interpolate(input_lq, (h, w))
+ real = torch.nn.functional.interpolate(ground_truth, (h, w))
+ fake_concat = torch.cat([lq, output], dim=1)
+ real_concat = torch.cat([lq, real], dim=1)
+ else:
+ fake_concat = torch.cat([input_lq, output], dim=1)
+ real_concat = torch.cat([input_lq, ground_truth], dim=1)
+
+ fake_and_real = torch.cat([fake_concat, real_concat], dim=0)
+ discriminator_out = self.net_d(fake_and_real)
+ pred_fake, pred_real = self._divide_pred(discriminator_out)
+ return pred_fake, pred_real
+
+ @staticmethod
+ def _divide_pred(pred):
+ """
+ Take the prediction of fake and real images from the combined batch.
+ The prediction contains the intermediate outputs of multiscale GAN,
+ so it's usually a list
+ """
+ if type(pred) == list:
+ fake = []
+ real = []
+ for p in pred:
+ fake.append([tensor[:tensor.size(0) // 2] for tensor in p])
+ real.append([tensor[tensor.size(0) // 2:] for tensor in p])
+ else:
+ fake = pred[:pred.size(0) // 2]
+ real = pred[pred.size(0) // 2:]
+
+ return fake, real
+
+ def optimize_parameters(self, current_iter):
+ # optimize net_g
+ for p in self.net_d.parameters():
+ p.requires_grad = False
+
+ self.optimizer_g.zero_grad()
+ self.output = self.net_g(self.lq)
+
+ l_g_total = 0
+ loss_dict = OrderedDict()
+
+ if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
+ # pixel loss
+ if self.cri_pix:
+ l_g_pix = self.cri_pix(self.output, self.gt)
+ l_g_total += l_g_pix
+ loss_dict['l_g_pix'] = l_g_pix
+
+ # perceptual loss
+ if self.cri_perceptual:
+ l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
+ if l_g_percep is not None:
+ l_g_total += l_g_percep
+ loss_dict['l_g_percep'] = l_g_percep
+ if l_g_style is not None:
+ l_g_total += l_g_style
+ loss_dict['l_g_style'] = l_g_style
+
+ # Requires real prediction for feature matching loss
+ pred_fake, pred_real = self.discriminate(self.lq, self.output, self.gt)
+ l_g_gan = self.cri_gan(pred_fake, True, is_disc=False)
+ l_g_total += l_g_gan
+ loss_dict['l_g_gan'] = l_g_gan
+
+ # feature matching loss
+ if self.cri_feat:
+ l_g_feat = self.cri_feat(pred_fake, pred_real)
+ l_g_total += l_g_feat
+ loss_dict['l_g_feat'] = l_g_feat
+
+ l_g_total.backward()
+ self.optimizer_g.step()
+
+ # optimize net_d
+ for p in self.net_d.parameters():
+ p.requires_grad = True
+
+ self.optimizer_d.zero_grad()
+ # TODO: Benchmark test between HiFaceGAN and SRGAN implementation:
+ # SRGAN use the same fake output for discriminator update
+ # while HiFaceGAN regenerate a new output using updated net_g
+ # This should not make too much difference though. Stick to SRGAN now.
+ # -------------------------------------------------------------------
+ # ---------- Below are original HiFaceGAN code snippet --------------
+ # -------------------------------------------------------------------
+ # with torch.no_grad():
+ # fake_image = self.net_g(self.lq)
+ # fake_image = fake_image.detach()
+ # fake_image.requires_grad_()
+ # pred_fake, pred_real = self.discriminate(self.lq, fake_image, self.gt)
+
+ # real
+ pred_fake, pred_real = self.discriminate(self.lq, self.output.detach(), self.gt)
+ l_d_real = self.cri_gan(pred_real, True, is_disc=True)
+ loss_dict['l_d_real'] = l_d_real
+ # fake
+ l_d_fake = self.cri_gan(pred_fake, False, is_disc=True)
+ loss_dict['l_d_fake'] = l_d_fake
+
+ l_d_total = (l_d_real + l_d_fake) / 2
+ l_d_total.backward()
+ self.optimizer_d.step()
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
+
+ if self.ema_decay > 0:
+ print('HiFaceGAN does not support EMA now. Skipping.')
+
+ def validation(self, dataloader, current_iter, tb_logger, save_img=False):
+ """
+ Warning: HiFaceGAN requires train() mode even for validation
+ For more info, see https://github.com/Lotayou/Face-Renovation/issues/31
+
+ Args:
+ dataloader (torch.utils.data.DataLoader): Validation dataloader.
+ current_iter (int): Current iteration.
+ tb_logger (tensorboard logger): Tensorboard logger.
+ save_img (bool): Whether to save images. Default: False.
+ """
+
+ if self.opt['network_g']['type'] in ('HiFaceGAN', 'SPADEGenerator'):
+ self.net_g.train()
+
+ if self.opt['dist']:
+ self.dist_validation(dataloader, current_iter, tb_logger, save_img)
+ else:
+ print('In HiFaceGANModel: The new metrics package is under development. ' +
+ 'Using the super method for now (only PSNR & SSIM are supported).')
+ super().nondist_validation(dataloader, current_iter, tb_logger, save_img)
+
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ """
+ TODO: Validation using updated metric system
+ The metrics are now evaluated after all images have been tested
+ This allows batch processing, and also allows evaluation of
+ distributional metrics, such as:
+
+ @ Frechet Inception Distance: FID
+ @ Maximum Mean Discrepancy: MMD
+
+ Warning:
+ Need careful batch management for different inference settings.
+
+ """
+ dataset_name = dataloader.dataset.opt['name']
+ with_metrics = self.opt['val'].get('metrics') is not None
+ if with_metrics:
+ self.metric_results = dict() # {metric: 0 for metric in self.opt['val']['metrics'].keys()}
+ sr_tensors = []
+ gt_tensors = []
+
+ pbar = tqdm(total=len(dataloader), unit='image')
+ for val_data in dataloader:
+ img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
+ self.feed_data(val_data)
+ self.test()
+
+ visuals = self.get_current_visuals() # detached cpu tensor, non-squeeze
+ sr_tensors.append(visuals['result'])
+ if 'gt' in visuals:
+ gt_tensors.append(visuals['gt'])
+ del self.gt
+
+ # tentative fix for running out of GPU memory
+ del self.lq
+ del self.output
+ torch.cuda.empty_cache()
+
+ if save_img:
+ if self.opt['is_train']:
+ save_img_path = osp.join(self.opt['path']['visualization'], img_name,
+ f'{img_name}_{current_iter}.png')
+ else:
+ if self.opt['val']['suffix']:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+ f'{img_name}_{self.opt["val"]["suffix"]}.png')
+ else:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+ f'{img_name}_{self.opt["name"]}.png')
+
+ imwrite(tensor2img(visuals['result']), save_img_path)
+
+ pbar.update(1)
+ pbar.set_description(f'Test {img_name}')
+ pbar.close()
+
+ if with_metrics:
+ sr_pack = torch.cat(sr_tensors, dim=0)
+ gt_pack = torch.cat(gt_tensors, dim=0)
+ # calculate metrics
+ for name, opt_ in self.opt['val']['metrics'].items():
+ # The new metric caller automatically returns mean value
+ # FIXME: calculate_metric only supports two arguments, so this call cannot run successfully yet
+ self.metric_results[name] = calculate_metric(dict(sr_pack=sr_pack, gt_pack=gt_pack), opt_)
+ self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
+
+ def save(self, epoch, current_iter):
+ if hasattr(self, 'net_g_ema'):
+ print('HiFaceGAN does not support EMA now. Fallback to normal mode.')
+
+ self.save_network(self.net_g, 'net_g', current_iter)
+ self.save_network(self.net_d, 'net_d', current_iter)
+ self.save_training_state(epoch, current_iter)
diff --git a/StableSR/basicsr/models/lr_scheduler.py b/StableSR/basicsr/models/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..11e1c6c7a74f5233accda52370f92681d3d3cecf
--- /dev/null
+++ b/StableSR/basicsr/models/lr_scheduler.py
@@ -0,0 +1,96 @@
+import math
+from collections import Counter
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class MultiStepRestartLR(_LRScheduler):
+ """ MultiStep with restarts learning rate scheme.
+
+ Args:
+ optimizer (torch.optim.Optimizer): Torch optimizer.
+ milestones (list): Iterations that will decrease learning rate.
+ gamma (float): Decrease ratio. Default: 0.1.
+ restarts (list): Restart iterations. Default: [0].
+ restart_weights (list): Restart weights at each restart iteration.
+ Default: [1].
+ last_epoch (int): Used in _LRScheduler. Default: -1.
+ """
+
+ def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1):
+ self.milestones = Counter(milestones)
+ self.gamma = gamma
+ self.restarts = restarts
+ self.restart_weights = restart_weights
+ assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.'
+ super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if self.last_epoch in self.restarts:
+ weight = self.restart_weights[self.restarts.index(self.last_epoch)]
+ return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
+ if self.last_epoch not in self.milestones:
+ return [group['lr'] for group in self.optimizer.param_groups]
+ return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups]
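+
+ # Note: milestones is a Counter, so listing the same iteration twice (e.g.
+ # milestones=[250000, 250000]) applies gamma**2 at that iteration; with the
+ # default gamma=0.1 a single milestone multiplies every group lr by 0.1.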
+
+
+def get_position_from_periods(iteration, cumulative_period):
+ """Get the position from a period list.
+
+ It will return the index of the right-closest number in the period list.
+ For example, the cumulative_period = [100, 200, 300, 400],
+ if iteration == 50, return 0;
+ if iteration == 210, return 2;
+ if iteration == 300, return 2.
+
+ Args:
+ iteration (int): Current iteration.
+ cumulative_period (list[int]): Cumulative period list.
+
+ Returns:
+ int: The position of the right-closest number in the period list.
+ """
+ for i, period in enumerate(cumulative_period):
+ if iteration <= period:
+ return i
+
+
+class CosineAnnealingRestartLR(_LRScheduler):
+ """ Cosine annealing with restarts learning rate scheme.
+
+ An example of config:
+ periods = [10, 10, 10, 10]
+ restart_weights = [1, 0.5, 0.5, 0.5]
+ eta_min=1e-7
+
+ It has four cycles, each with 10 iterations. At the 10th, 20th and 30th
+ iterations, the scheduler restarts with the corresponding weights in restart_weights.
+
+ Args:
+ optimizer (torch.optim.Optimizer): Torch optimizer.
+ periods (list): Period for each cosine annealing cycle.
+ restart_weights (list): Restart weights at each restart iteration.
+ Default: [1].
+ eta_min (float): The minimum lr. Default: 0.
+ last_epoch (int): Used in _LRScheduler. Default: -1.
+ """
+
+ def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1):
+ self.periods = periods
+ self.restart_weights = restart_weights
+ self.eta_min = eta_min
+ assert (len(self.periods) == len(
+ self.restart_weights)), 'periods and restart_weights should have the same length.'
+ self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))]
+ super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ idx = get_position_from_periods(self.last_epoch, self.cumulative_period)
+ current_weight = self.restart_weights[idx]
+ nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
+ current_period = self.periods[idx]
+
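+ # Per-group lr follows eta_min + w_i * 0.5 * (base_lr - eta_min) *
+ # (1 + cos(pi * (t - t_restart) / T_i)), i.e. a cosine decay from
+ # eta_min + w_i * (base_lr - eta_min) down to eta_min over the current period.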
+ return [
+ self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
+ (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period)))
+ for base_lr in self.base_lrs
+ ]
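+
+
+# A minimal usage sketch (the model, optimizer and loop below are hypothetical, not part of this module):
+#
+#     import torch
+#     model = torch.nn.Linear(4, 4)
+#     optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
+#     scheduler = CosineAnnealingRestartLR(
+#         optimizer, periods=[250000, 250000], restart_weights=[1, 0.5], eta_min=1e-7)
+#     for _ in range(sum(scheduler.periods)):
+#         optimizer.step()
+#         scheduler.step()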
diff --git a/StableSR/basicsr/models/realesrgan_model.py b/StableSR/basicsr/models/realesrgan_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c74b28fb1dc6a7f5c5ad3f7d8bb96c19c52ee92b
--- /dev/null
+++ b/StableSR/basicsr/models/realesrgan_model.py
@@ -0,0 +1,267 @@
+import numpy as np
+import random
+import torch
+from collections import OrderedDict
+from torch.nn import functional as F
+
+from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
+from basicsr.data.transforms import paired_random_crop
+from basicsr.losses.loss_util import get_refined_artifact_map
+from basicsr.models.srgan_model import SRGANModel
+from basicsr.utils import DiffJPEG, USMSharp
+from basicsr.utils.img_process_util import filter2D
+from basicsr.utils.registry import MODEL_REGISTRY
+
+
+@MODEL_REGISTRY.register(suffix='basicsr')
+class RealESRGANModel(SRGANModel):
+ """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+ It mainly performs:
+ 1. randomly synthesize LQ images in GPU tensors
+ 2. optimize the networks with GAN training.
+ """
+
+ def __init__(self, opt):
+ super(RealESRGANModel, self).__init__(opt)
+ self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
+ self.usm_sharpener = USMSharp().cuda() # do usm sharpening
+ self.queue_size = opt.get('queue_size', 180)
+
+ @torch.no_grad()
+ def _dequeue_and_enqueue(self):
+ """It is the training pair pool for increasing the diversity in a batch.
+
+ Batch processing limits the diversity of synthetic degradations within a batch. For example, samples in a
+ batch cannot have different resize scaling factors. We therefore employ this training pair pool
+ to increase the degradation diversity within a batch.
+ """
+ # initialize
+ b, c, h, w = self.lq.size()
+ if not hasattr(self, 'queue_lr'):
+ assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
+ self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
+ _, c, h, w = self.gt.size()
+ self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
+ self.queue_ptr = 0
+ if self.queue_ptr == self.queue_size: # the pool is full
+ # do dequeue and enqueue
+ # shuffle
+ idx = torch.randperm(self.queue_size)
+ self.queue_lr = self.queue_lr[idx]
+ self.queue_gt = self.queue_gt[idx]
+ # get first b samples
+ lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
+ gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
+ # update the queue
+ self.queue_lr[0:b, :, :, :] = self.lq.clone()
+ self.queue_gt[0:b, :, :, :] = self.gt.clone()
+
+ self.lq = lq_dequeue
+ self.gt = gt_dequeue
+ else:
+ # only do enqueue
+ self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
+ self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
+ self.queue_ptr = self.queue_ptr + b
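+
+ # Example with the default queue_size=180 and an illustrative batch size of 12: the
+ # first 15 calls only enqueue; once the pool is full, each call shuffles it, returns
+ # its first 12 samples as the new (lq, gt) pair and writes the incoming batch back
+ # into those slots.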
+
+ @torch.no_grad()
+ def feed_data(self, data):
+ """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
+ """
+ if self.is_train and self.opt.get('high_order_degradation', True):
+ # training data synthesis
+ self.gt = data['gt'].to(self.device)
+ self.gt_usm = self.usm_sharpener(self.gt)
+
+ self.kernel1 = data['kernel1'].to(self.device)
+ self.kernel2 = data['kernel2'].to(self.device)
+ self.sinc_kernel = data['sinc_kernel'].to(self.device)
+
+ ori_h, ori_w = self.gt.size()[2:4]
+
+ # ----------------------- The first degradation process ----------------------- #
+ # blur
+ out = filter2D(self.gt_usm, self.kernel1)
+ # random resize
+ updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
+ if updown_type == 'up':
+ scale = np.random.uniform(1, self.opt['resize_range'][1])
+ elif updown_type == 'down':
+ scale = np.random.uniform(self.opt['resize_range'][0], 1)
+ else:
+ scale = 1
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(out, scale_factor=scale, mode=mode)
+ # add noise
+ gray_noise_prob = self.opt['gray_noise_prob']
+ if np.random.uniform() < self.opt['gaussian_noise_prob']:
+ out = random_add_gaussian_noise_pt(
+ out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
+ else:
+ out = random_add_poisson_noise_pt(
+ out,
+ scale_range=self.opt['poisson_scale_range'],
+ gray_prob=gray_noise_prob,
+ clip=True,
+ rounds=False)
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
+ out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
+ out = self.jpeger(out, quality=jpeg_p)
+
+ # ----------------------- The second degradation process ----------------------- #
+ # blur
+ if np.random.uniform() < self.opt['second_blur_prob']:
+ out = filter2D(out, self.kernel2)
+ # random resize
+ updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
+ if updown_type == 'up':
+ scale = np.random.uniform(1, self.opt['resize_range2'][1])
+ elif updown_type == 'down':
+ scale = np.random.uniform(self.opt['resize_range2'][0], 1)
+ else:
+ scale = 1
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(
+ out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
+ # add noise
+ gray_noise_prob = self.opt['gray_noise_prob2']
+ if np.random.uniform() < self.opt['gaussian_noise_prob2']:
+ out = random_add_gaussian_noise_pt(
+ out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
+ else:
+ out = random_add_poisson_noise_pt(
+ out,
+ scale_range=self.opt['poisson_scale_range2'],
+ gray_prob=gray_noise_prob,
+ clip=True,
+ rounds=False)
+
+ # JPEG compression + the final sinc filter
+ # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
+ # as one operation.
+ # We consider two orders:
+ # 1. [resize back + sinc filter] + JPEG compression
+ # 2. JPEG compression + [resize back + sinc filter]
+ # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
+ if np.random.uniform() < 0.5:
+ # resize back + the final sinc filter
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
+ out = filter2D(out, self.sinc_kernel)
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
+ out = torch.clamp(out, 0, 1)
+ out = self.jpeger(out, quality=jpeg_p)
+ else:
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
+ out = torch.clamp(out, 0, 1)
+ out = self.jpeger(out, quality=jpeg_p)
+ # resize back + the final sinc filter
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
+ out = filter2D(out, self.sinc_kernel)
+
+ # clamp and round
+ self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+
+ # random crop
+ gt_size = self.opt['gt_size']
+ (self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size,
+ self.opt['scale'])
+
+ # training pair pool
+ self._dequeue_and_enqueue()
+ # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
+ self.gt_usm = self.usm_sharpener(self.gt)
+ self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
+ else:
+ # for paired training or validation
+ self.lq = data['lq'].to(self.device)
+ if 'gt' in data:
+ self.gt = data['gt'].to(self.device)
+ self.gt_usm = self.usm_sharpener(self.gt)
+
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ # do not use the synthetic process during validation
+ self.is_train = False
+ super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
+ self.is_train = True
+
+ def optimize_parameters(self, current_iter):
+ # usm sharpening
+ l1_gt = self.gt_usm
+ percep_gt = self.gt_usm
+ gan_gt = self.gt_usm
+ if self.opt['l1_gt_usm'] is False:
+ l1_gt = self.gt
+ if self.opt['percep_gt_usm'] is False:
+ percep_gt = self.gt
+ if self.opt['gan_gt_usm'] is False:
+ gan_gt = self.gt
+
+ # optimize net_g
+ for p in self.net_d.parameters():
+ p.requires_grad = False
+
+ self.optimizer_g.zero_grad()
+ self.output = self.net_g(self.lq)
+ if self.cri_ldl:
+ self.output_ema = self.net_g_ema(self.lq)
+
+ l_g_total = 0
+ loss_dict = OrderedDict()
+ if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
+ # pixel loss
+ if self.cri_pix:
+ l_g_pix = self.cri_pix(self.output, l1_gt)
+ l_g_total += l_g_pix
+ loss_dict['l_g_pix'] = l_g_pix
+ if self.cri_ldl:
+ pixel_weight = get_refined_artifact_map(self.gt, self.output, self.output_ema, 7)
+ l_g_ldl = self.cri_ldl(torch.mul(pixel_weight, self.output), torch.mul(pixel_weight, self.gt))
+ l_g_total += l_g_ldl
+ loss_dict['l_g_ldl'] = l_g_ldl
+ # perceptual loss
+ if self.cri_perceptual:
+ l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt)
+ if l_g_percep is not None:
+ l_g_total += l_g_percep
+ loss_dict['l_g_percep'] = l_g_percep
+ if l_g_style is not None:
+ l_g_total += l_g_style
+ loss_dict['l_g_style'] = l_g_style
+ # gan loss
+ fake_g_pred = self.net_d(self.output)
+ l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
+ l_g_total += l_g_gan
+ loss_dict['l_g_gan'] = l_g_gan
+
+ l_g_total.backward()
+ self.optimizer_g.step()
+
+ # optimize net_d
+ for p in self.net_d.parameters():
+ p.requires_grad = True
+
+ self.optimizer_d.zero_grad()
+ # real
+ real_d_pred = self.net_d(gan_gt)
+ l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
+ loss_dict['l_d_real'] = l_d_real
+ loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
+ l_d_real.backward()
+ # fake
+ fake_d_pred = self.net_d(self.output.detach().clone()) # clone for pt1.9
+ l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
+ loss_dict['l_d_fake'] = l_d_fake
+ loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
+ l_d_fake.backward()
+ self.optimizer_d.step()
+
+ if self.ema_decay > 0:
+ self.model_ema(decay=self.ema_decay)
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
diff --git a/StableSR/basicsr/models/realesrnet_model.py b/StableSR/basicsr/models/realesrnet_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5790918b969682a0db0e2ed9236b7046d627b90
--- /dev/null
+++ b/StableSR/basicsr/models/realesrnet_model.py
@@ -0,0 +1,189 @@
+import numpy as np
+import random
+import torch
+from torch.nn import functional as F
+
+from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
+from basicsr.data.transforms import paired_random_crop
+from basicsr.models.sr_model import SRModel
+from basicsr.utils import DiffJPEG, USMSharp
+from basicsr.utils.img_process_util import filter2D
+from basicsr.utils.registry import MODEL_REGISTRY
+
+
+@MODEL_REGISTRY.register(suffix='basicsr')
+class RealESRNetModel(SRModel):
+ """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+ It is trained without GAN losses.
+ It mainly performs:
+ 1. randomly synthesize LQ images in GPU tensors
+ 2. optimize the networks without GAN training.
+ """
+
+ def __init__(self, opt):
+ super(RealESRNetModel, self).__init__(opt)
+ self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
+ self.usm_sharpener = USMSharp().cuda() # do usm sharpening
+ self.queue_size = opt.get('queue_size', 180)
+
+ @torch.no_grad()
+ def _dequeue_and_enqueue(self):
+ """It is the training pair pool for increasing the diversity in a batch.
+
+ Batch processing limits the diversity of synthetic degradations within a batch. For example, samples in a
+ batch cannot have different resize scaling factors. We therefore employ this training pair pool
+ to increase the degradation diversity within a batch.
+ """
+ # initialize
+ b, c, h, w = self.lq.size()
+ if not hasattr(self, 'queue_lr'):
+ assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
+ self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
+ _, c, h, w = self.gt.size()
+ self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
+ self.queue_ptr = 0
+ if self.queue_ptr == self.queue_size: # the pool is full
+ # do dequeue and enqueue
+ # shuffle
+ idx = torch.randperm(self.queue_size)
+ self.queue_lr = self.queue_lr[idx]
+ self.queue_gt = self.queue_gt[idx]
+ # get first b samples
+ lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
+ gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
+ # update the queue
+ self.queue_lr[0:b, :, :, :] = self.lq.clone()
+ self.queue_gt[0:b, :, :, :] = self.gt.clone()
+
+ self.lq = lq_dequeue
+ self.gt = gt_dequeue
+ else:
+ # only do enqueue
+ self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
+ self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
+ self.queue_ptr = self.queue_ptr + b
+
+ @torch.no_grad()
+ def feed_data(self, data):
+ """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
+ """
+ if self.is_train and self.opt.get('high_order_degradation', True):
+ # training data synthesis
+ self.gt = data['gt'].to(self.device)
+ # USM sharpen the GT images
+ if self.opt['gt_usm'] is True:
+ self.gt = self.usm_sharpener(self.gt)
+
+ self.kernel1 = data['kernel1'].to(self.device)
+ self.kernel2 = data['kernel2'].to(self.device)
+ self.sinc_kernel = data['sinc_kernel'].to(self.device)
+
+ ori_h, ori_w = self.gt.size()[2:4]
+
+ # ----------------------- The first degradation process ----------------------- #
+ # blur
+ out = filter2D(self.gt, self.kernel1)
+ # random resize
+ updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
+ if updown_type == 'up':
+ scale = np.random.uniform(1, self.opt['resize_range'][1])
+ elif updown_type == 'down':
+ scale = np.random.uniform(self.opt['resize_range'][0], 1)
+ else:
+ scale = 1
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(out, scale_factor=scale, mode=mode)
+ # add noise
+ gray_noise_prob = self.opt['gray_noise_prob']
+ if np.random.uniform() < self.opt['gaussian_noise_prob']:
+ out = random_add_gaussian_noise_pt(
+ out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
+ else:
+ out = random_add_poisson_noise_pt(
+ out,
+ scale_range=self.opt['poisson_scale_range'],
+ gray_prob=gray_noise_prob,
+ clip=True,
+ rounds=False)
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
+ out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
+ out = self.jpeger(out, quality=jpeg_p)
+
+ # ----------------------- The second degradation process ----------------------- #
+ # blur
+ if np.random.uniform() < self.opt['second_blur_prob']:
+ out = filter2D(out, self.kernel2)
+ # random resize
+ updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
+ if updown_type == 'up':
+ scale = np.random.uniform(1, self.opt['resize_range2'][1])
+ elif updown_type == 'down':
+ scale = np.random.uniform(self.opt['resize_range2'][0], 1)
+ else:
+ scale = 1
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(
+ out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
+ # add noise
+ gray_noise_prob = self.opt['gray_noise_prob2']
+ if np.random.uniform() < self.opt['gaussian_noise_prob2']:
+ out = random_add_gaussian_noise_pt(
+ out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
+ else:
+ out = random_add_poisson_noise_pt(
+ out,
+ scale_range=self.opt['poisson_scale_range2'],
+ gray_prob=gray_noise_prob,
+ clip=True,
+ rounds=False)
+
+ # JPEG compression + the final sinc filter
+ # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
+ # as one operation.
+ # We consider two orders:
+ # 1. [resize back + sinc filter] + JPEG compression
+ # 2. JPEG compression + [resize back + sinc filter]
+ # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
+ if np.random.uniform() < 0.5:
+ # resize back + the final sinc filter
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
+ out = filter2D(out, self.sinc_kernel)
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
+ out = torch.clamp(out, 0, 1)
+ out = self.jpeger(out, quality=jpeg_p)
+ else:
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
+ out = torch.clamp(out, 0, 1)
+ out = self.jpeger(out, quality=jpeg_p)
+ # resize back + the final sinc filter
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
+ out = filter2D(out, self.sinc_kernel)
+
+ # clamp and round
+ self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+
+ # random crop
+ gt_size = self.opt['gt_size']
+ self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale'])
+
+ # training pair pool
+ self._dequeue_and_enqueue()
+ self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
+ else:
+ # for paired training or validation
+ self.lq = data['lq'].to(self.device)
+ if 'gt' in data:
+ self.gt = data['gt'].to(self.device)
+ self.gt_usm = self.usm_sharpener(self.gt)
+
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ # do not use the synthetic process during validation
+ self.is_train = False
+ super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
+ self.is_train = True
diff --git a/StableSR/basicsr/models/sr_model.py b/StableSR/basicsr/models/sr_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..787f1fd2eab5963579c764c1bfb87199b7dd196f
--- /dev/null
+++ b/StableSR/basicsr/models/sr_model.py
@@ -0,0 +1,279 @@
+import torch
+from collections import OrderedDict
+from os import path as osp
+from tqdm import tqdm
+
+from basicsr.archs import build_network
+from basicsr.losses import build_loss
+from basicsr.metrics import calculate_metric
+from basicsr.utils import get_root_logger, imwrite, tensor2img
+from basicsr.utils.registry import MODEL_REGISTRY
+from .base_model import BaseModel
+
+
+@MODEL_REGISTRY.register()
+class SRModel(BaseModel):
+ """Base SR model for single image super-resolution."""
+
+ def __init__(self, opt):
+ super(SRModel, self).__init__(opt)
+
+ # define network
+ self.net_g = build_network(opt['network_g'])
+ self.net_g = self.model_to_device(self.net_g)
+ self.print_network(self.net_g)
+
+ # load pretrained models
+ load_path = self.opt['path'].get('pretrain_network_g', None)
+ if load_path is not None:
+ param_key = self.opt['path'].get('param_key_g', 'params')
+ self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
+
+ if self.is_train:
+ self.init_training_settings()
+
+ def init_training_settings(self):
+ self.net_g.train()
+ train_opt = self.opt['train']
+
+ self.ema_decay = train_opt.get('ema_decay', 0)
+ if self.ema_decay > 0:
+ logger = get_root_logger()
+ logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
+ # define network net_g with Exponential Moving Average (EMA)
+ # net_g_ema is used only for testing on one GPU and saving
+ # There is no need to wrap with DistributedDataParallel
+ self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
+ # load pretrained model
+ load_path = self.opt['path'].get('pretrain_network_g', None)
+ if load_path is not None:
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
+ else:
+ self.model_ema(0) # copy net_g weight
+ self.net_g_ema.eval()
+
+ # define losses
+ if train_opt.get('pixel_opt'):
+ self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
+ else:
+ self.cri_pix = None
+
+ if train_opt.get('perceptual_opt'):
+ self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
+ else:
+ self.cri_perceptual = None
+
+ if self.cri_pix is None and self.cri_perceptual is None:
+ raise ValueError('Both pixel and perceptual losses are None.')
+
+ # set up optimizers and schedulers
+ self.setup_optimizers()
+ self.setup_schedulers()
+
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ optim_params = []
+ for k, v in self.net_g.named_parameters():
+ if v.requires_grad:
+ optim_params.append(v)
+ else:
+ logger = get_root_logger()
+ logger.warning(f'Params {k} will not be optimized.')
+
+ optim_type = train_opt['optim_g'].pop('type')
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
+ self.optimizers.append(self.optimizer_g)
+
+ def feed_data(self, data):
+ self.lq = data['lq'].to(self.device)
+ if 'gt' in data:
+ self.gt = data['gt'].to(self.device)
+
+ def optimize_parameters(self, current_iter):
+ self.optimizer_g.zero_grad()
+ self.output = self.net_g(self.lq)
+
+ l_total = 0
+ loss_dict = OrderedDict()
+ # pixel loss
+ if self.cri_pix:
+ l_pix = self.cri_pix(self.output, self.gt)
+ l_total += l_pix
+ loss_dict['l_pix'] = l_pix
+ # perceptual loss
+ if self.cri_perceptual:
+ l_percep, l_style = self.cri_perceptual(self.output, self.gt)
+ if l_percep is not None:
+ l_total += l_percep
+ loss_dict['l_percep'] = l_percep
+ if l_style is not None:
+ l_total += l_style
+ loss_dict['l_style'] = l_style
+
+ l_total.backward()
+ self.optimizer_g.step()
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
+
+ if self.ema_decay > 0:
+ self.model_ema(decay=self.ema_decay)
+
+ def test(self):
+ if hasattr(self, 'net_g_ema'):
+ self.net_g_ema.eval()
+ with torch.no_grad():
+ self.output = self.net_g_ema(self.lq)
+ else:
+ self.net_g.eval()
+ with torch.no_grad():
+ self.output = self.net_g(self.lq)
+ self.net_g.train()
+
+ def test_selfensemble(self):
+ # TODO: to be tested
+ # 8 augmentations
+ # modified from https://github.com/thstkdgus35/EDSR-PyTorch
+
+ def _transform(v, op):
+ # if self.precision != 'single': v = v.float()
+ v2np = v.data.cpu().numpy()
+ if op == 'v':
+ tfnp = v2np[:, :, :, ::-1].copy()
+ elif op == 'h':
+ tfnp = v2np[:, :, ::-1, :].copy()
+ elif op == 't':
+ tfnp = v2np.transpose((0, 1, 3, 2)).copy()
+
+ ret = torch.Tensor(tfnp).to(self.device)
+ # if self.precision == 'half': ret = ret.half()
+
+ return ret
+
+ # prepare augmented data
+ lq_list = [self.lq]
+ for tf in 'v', 'h', 't':
+ lq_list.extend([_transform(t, tf) for t in lq_list])
+
+ # inference
+ if hasattr(self, 'net_g_ema'):
+ self.net_g_ema.eval()
+ with torch.no_grad():
+ out_list = [self.net_g_ema(aug) for aug in lq_list]
+ else:
+ self.net_g.eval()
+ with torch.no_grad():
+ out_list = [self.net_g(aug) for aug in lq_list]
+ self.net_g.train()
+
+ # merge results
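+ # lq_list is built so that bit 0 of an index encodes a 'v' flip, bit 1 an 'h' flip
+ # and bit 2 a transpose; since each op is its own inverse, applying the same op
+ # again restores the original orientation before averaging.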
+ for i in range(len(out_list)):
+ if i > 3:
+ out_list[i] = _transform(out_list[i], 't')
+ if i % 4 > 1:
+ out_list[i] = _transform(out_list[i], 'h')
+ if (i % 4) % 2 == 1:
+ out_list[i] = _transform(out_list[i], 'v')
+ output = torch.cat(out_list, dim=0)
+
+ self.output = output.mean(dim=0, keepdim=True)
+
+ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ if self.opt['rank'] == 0:
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
+
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ dataset_name = dataloader.dataset.opt['name']
+ with_metrics = self.opt['val'].get('metrics') is not None
+ use_pbar = self.opt['val'].get('pbar', False)
+
+ if with_metrics:
+ if not hasattr(self, 'metric_results'): # only execute in the first run
+ self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
+ # initialize the best metric results for each dataset_name (supporting multiple validation datasets)
+ self._initialize_best_metric_results(dataset_name)
+ # zero self.metric_results
+ if with_metrics:
+ self.metric_results = {metric: 0 for metric in self.metric_results}
+
+ metric_data = dict()
+ if use_pbar:
+ pbar = tqdm(total=len(dataloader), unit='image')
+
+ for idx, val_data in enumerate(dataloader):
+ img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
+ self.feed_data(val_data)
+ self.test()
+
+ visuals = self.get_current_visuals()
+ sr_img = tensor2img([visuals['result']])
+ metric_data['img'] = sr_img
+ if 'gt' in visuals:
+ gt_img = tensor2img([visuals['gt']])
+ metric_data['img2'] = gt_img
+ del self.gt
+
+ # tentative fix for running out of GPU memory
+ del self.lq
+ del self.output
+ torch.cuda.empty_cache()
+
+ if save_img:
+ if self.opt['is_train']:
+ save_img_path = osp.join(self.opt['path']['visualization'], img_name,
+ f'{img_name}_{current_iter}.png')
+ else:
+ if self.opt['val']['suffix']:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+ f'{img_name}_{self.opt["val"]["suffix"]}.png')
+ else:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
+ f'{img_name}_{self.opt["name"]}.png')
+ imwrite(sr_img, save_img_path)
+
+ if with_metrics:
+ # calculate metrics
+ for name, opt_ in self.opt['val']['metrics'].items():
+ self.metric_results[name] += calculate_metric(metric_data, opt_)
+ if use_pbar:
+ pbar.update(1)
+ pbar.set_description(f'Test {img_name}')
+ if use_pbar:
+ pbar.close()
+
+ if with_metrics:
+ for metric in self.metric_results.keys():
+ self.metric_results[metric] /= (idx + 1)
+ # update the best metric result
+ self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)
+
+ self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
+
+ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
+ log_str = f'Validation {dataset_name}\n'
+ for metric, value in self.metric_results.items():
+ log_str += f'\t # {metric}: {value:.4f}'
+ if hasattr(self, 'best_metric_results'):
+ log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
+ f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
+ log_str += '\n'
+
+ logger = get_root_logger()
+ logger.info(log_str)
+ if tb_logger:
+ for metric, value in self.metric_results.items():
+ tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)
+
+ def get_current_visuals(self):
+ out_dict = OrderedDict()
+ out_dict['lq'] = self.lq.detach().cpu()
+ out_dict['result'] = self.output.detach().cpu()
+ if hasattr(self, 'gt'):
+ out_dict['gt'] = self.gt.detach().cpu()
+ return out_dict
+
+ def save(self, epoch, current_iter):
+ if hasattr(self, 'net_g_ema'):
+ self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
+ else:
+ self.save_network(self.net_g, 'net_g', current_iter)
+ self.save_training_state(epoch, current_iter)
diff --git a/StableSR/basicsr/models/srgan_model.py b/StableSR/basicsr/models/srgan_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..45387ca7908e3f38f59a605adb8242ad12fcf1a1
--- /dev/null
+++ b/StableSR/basicsr/models/srgan_model.py
@@ -0,0 +1,149 @@
+import torch
+from collections import OrderedDict
+
+from basicsr.archs import build_network
+from basicsr.losses import build_loss
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import MODEL_REGISTRY
+from .sr_model import SRModel
+
+
+@MODEL_REGISTRY.register()
+class SRGANModel(SRModel):
+ """SRGAN model for single image super-resolution."""
+
+ def init_training_settings(self):
+ train_opt = self.opt['train']
+
+ self.ema_decay = train_opt.get('ema_decay', 0)
+ if self.ema_decay > 0:
+ logger = get_root_logger()
+ logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
+ # define network net_g with Exponential Moving Average (EMA)
+ # net_g_ema is used only for testing on one GPU and saving
+ # There is no need to wrap with DistributedDataParallel
+ self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
+ # load pretrained model
+ load_path = self.opt['path'].get('pretrain_network_g', None)
+ if load_path is not None:
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
+ else:
+ self.model_ema(0) # copy net_g weight
+ self.net_g_ema.eval()
+
+ # define network net_d
+ self.net_d = build_network(self.opt['network_d'])
+ self.net_d = self.model_to_device(self.net_d)
+ self.print_network(self.net_d)
+
+ # load pretrained models
+ load_path = self.opt['path'].get('pretrain_network_d', None)
+ if load_path is not None:
+ param_key = self.opt['path'].get('param_key_d', 'params')
+ self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)
+
+ self.net_g.train()
+ self.net_d.train()
+
+ # define losses
+ if train_opt.get('pixel_opt'):
+ self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
+ else:
+ self.cri_pix = None
+
+ if train_opt.get('ldl_opt'):
+ self.cri_ldl = build_loss(train_opt['ldl_opt']).to(self.device)
+ else:
+ self.cri_ldl = None
+
+ if train_opt.get('perceptual_opt'):
+ self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
+ else:
+ self.cri_perceptual = None
+
+ if train_opt.get('gan_opt'):
+ self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
+
+ self.net_d_iters = train_opt.get('net_d_iters', 1)
+ self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
+
+ # set up optimizers and schedulers
+ self.setup_optimizers()
+ self.setup_schedulers()
+
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ # optimizer g
+ optim_type = train_opt['optim_g'].pop('type')
+ self.optimizer_g = self.get_optimizer(optim_type, self.net_g.parameters(), **train_opt['optim_g'])
+ self.optimizers.append(self.optimizer_g)
+ # optimizer d
+ optim_type = train_opt['optim_d'].pop('type')
+ self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
+ self.optimizers.append(self.optimizer_d)
+
+ def optimize_parameters(self, current_iter):
+ # optimize net_g
+ for p in self.net_d.parameters():
+ p.requires_grad = False
+
+ self.optimizer_g.zero_grad()
+ self.output = self.net_g(self.lq)
+
+ l_g_total = 0
+ loss_dict = OrderedDict()
+ if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
+ # pixel loss
+ if self.cri_pix:
+ l_g_pix = self.cri_pix(self.output, self.gt)
+ l_g_total += l_g_pix
+ loss_dict['l_g_pix'] = l_g_pix
+ # perceptual loss
+ if self.cri_perceptual:
+ l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
+ if l_g_percep is not None:
+ l_g_total += l_g_percep
+ loss_dict['l_g_percep'] = l_g_percep
+ if l_g_style is not None:
+ l_g_total += l_g_style
+ loss_dict['l_g_style'] = l_g_style
+ # gan loss
+ fake_g_pred = self.net_d(self.output)
+ l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
+ l_g_total += l_g_gan
+ loss_dict['l_g_gan'] = l_g_gan
+
+ l_g_total.backward()
+ self.optimizer_g.step()
+
+ # optimize net_d
+ for p in self.net_d.parameters():
+ p.requires_grad = True
+
+ self.optimizer_d.zero_grad()
+ # real
+ real_d_pred = self.net_d(self.gt)
+ l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
+ loss_dict['l_d_real'] = l_d_real
+ loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
+ l_d_real.backward()
+ # fake
+ fake_d_pred = self.net_d(self.output.detach())
+ l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
+ loss_dict['l_d_fake'] = l_d_fake
+ loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
+ l_d_fake.backward()
+ self.optimizer_d.step()
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
+
+ if self.ema_decay > 0:
+ self.model_ema(decay=self.ema_decay)
+
+ def save(self, epoch, current_iter):
+ if hasattr(self, 'net_g_ema'):
+ self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
+ else:
+ self.save_network(self.net_g, 'net_g', current_iter)
+ self.save_network(self.net_d, 'net_d', current_iter)
+ self.save_training_state(epoch, current_iter)
diff --git a/StableSR/basicsr/models/stylegan2_model.py b/StableSR/basicsr/models/stylegan2_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7da708122160f2be51a98a6a635349f34ee042e
--- /dev/null
+++ b/StableSR/basicsr/models/stylegan2_model.py
@@ -0,0 +1,283 @@
+import cv2
+import math
+import numpy as np
+import random
+import torch
+from collections import OrderedDict
+from os import path as osp
+
+from basicsr.archs import build_network
+from basicsr.losses import build_loss
+from basicsr.losses.gan_loss import g_path_regularize, r1_penalty
+from basicsr.utils import imwrite, tensor2img
+from basicsr.utils.registry import MODEL_REGISTRY
+from .base_model import BaseModel
+
+
+@MODEL_REGISTRY.register()
+class StyleGAN2Model(BaseModel):
+ """StyleGAN2 model."""
+
+ def __init__(self, opt):
+ super(StyleGAN2Model, self).__init__(opt)
+
+ # define network net_g
+ self.net_g = build_network(opt['network_g'])
+ self.net_g = self.model_to_device(self.net_g)
+ self.print_network(self.net_g)
+ # load pretrained model
+ load_path = self.opt['path'].get('pretrain_network_g', None)
+ if load_path is not None:
+ param_key = self.opt['path'].get('param_key_g', 'params')
+ self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
+
+ # latent dimension: self.num_style_feat
+ self.num_style_feat = opt['network_g']['num_style_feat']
+ num_val_samples = self.opt['val'].get('num_val_samples', 16)
+ self.fixed_sample = torch.randn(num_val_samples, self.num_style_feat, device=self.device)
+
+ if self.is_train:
+ self.init_training_settings()
+
+ def init_training_settings(self):
+ train_opt = self.opt['train']
+
+ # define network net_d
+ self.net_d = build_network(self.opt['network_d'])
+ self.net_d = self.model_to_device(self.net_d)
+ self.print_network(self.net_d)
+
+ # load pretrained model
+ load_path = self.opt['path'].get('pretrain_network_d', None)
+ if load_path is not None:
+ param_key = self.opt['path'].get('param_key_d', 'params')
+ self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)
+
+ # define network net_g with Exponential Moving Average (EMA)
+ # net_g_ema only used for testing on one GPU and saving, do not need to
+ # wrap with DistributedDataParallel
+ self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
+ # load pretrained model
+ load_path = self.opt['path'].get('pretrain_network_g', None)
+ if load_path is not None:
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
+ else:
+ self.model_ema(0) # copy net_g weight
+
+ self.net_g.train()
+ self.net_d.train()
+ self.net_g_ema.eval()
+
+ # define losses
+ # gan loss (wgan)
+ self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
+ # regularization weights
+ self.r1_reg_weight = train_opt['r1_reg_weight'] # for discriminator
+ self.path_reg_weight = train_opt['path_reg_weight'] # for generator
+
+ self.net_g_reg_every = train_opt['net_g_reg_every']
+ self.net_d_reg_every = train_opt['net_d_reg_every']
+ self.mixing_prob = train_opt['mixing_prob']
+
+ self.mean_path_length = 0
+
+ # set up optimizers and schedulers
+ self.setup_optimizers()
+ self.setup_schedulers()
+
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ # optimizer g
+ net_g_reg_ratio = self.net_g_reg_every / (self.net_g_reg_every + 1)
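+ # Following the StyleGAN2 lazy-regularization heuristic: the path-length / R1 penalties
+ # are only applied every `net_*_reg_every` steps, so lr and Adam betas are rescaled by
+ # c = k / (k + 1) (lr * c, beta ** c) to keep optimization roughly equivalent.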
+ if self.opt['network_g']['type'] == 'StyleGAN2GeneratorC':
+ normal_params = []
+ style_mlp_params = []
+ modulation_conv_params = []
+ for name, param in self.net_g.named_parameters():
+ if 'modulation' in name:
+ normal_params.append(param)
+ elif 'style_mlp' in name:
+ style_mlp_params.append(param)
+ elif 'modulated_conv' in name:
+ modulation_conv_params.append(param)
+ else:
+ normal_params.append(param)
+ optim_params_g = [
+ { # add normal params first
+ 'params': normal_params,
+ 'lr': train_opt['optim_g']['lr']
+ },
+ {
+ 'params': style_mlp_params,
+ 'lr': train_opt['optim_g']['lr'] * 0.01
+ },
+ {
+ 'params': modulation_conv_params,
+ 'lr': train_opt['optim_g']['lr'] / 3
+ }
+ ]
+ else:
+ normal_params = []
+ for name, param in self.net_g.named_parameters():
+ normal_params.append(param)
+ optim_params_g = [{ # add normal params first
+ 'params': normal_params,
+ 'lr': train_opt['optim_g']['lr']
+ }]
+
+ optim_type = train_opt['optim_g'].pop('type')
+ lr = train_opt['optim_g']['lr'] * net_g_reg_ratio
+ betas = (0**net_g_reg_ratio, 0.99**net_g_reg_ratio)
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, lr, betas=betas)
+ self.optimizers.append(self.optimizer_g)
+
+ # optimizer d
+ net_d_reg_ratio = self.net_d_reg_every / (self.net_d_reg_every + 1)
+ if self.opt['network_d']['type'] == 'StyleGAN2DiscriminatorC':
+ normal_params = []
+ linear_params = []
+ for name, param in self.net_d.named_parameters():
+ if 'final_linear' in name:
+ linear_params.append(param)
+ else:
+ normal_params.append(param)
+ optim_params_d = [
+ { # add normal params first
+ 'params': normal_params,
+ 'lr': train_opt['optim_d']['lr']
+ },
+ {
+ 'params': linear_params,
+ 'lr': train_opt['optim_d']['lr'] * (1 / math.sqrt(512))
+ }
+ ]
+ else:
+ normal_params = []
+ for name, param in self.net_d.named_parameters():
+ normal_params.append(param)
+ optim_params_d = [{ # add normal params first
+ 'params': normal_params,
+ 'lr': train_opt['optim_d']['lr']
+ }]
+
+ optim_type = train_opt['optim_d'].pop('type')
+ lr = train_opt['optim_d']['lr'] * net_d_reg_ratio
+ betas = (0**net_d_reg_ratio, 0.99**net_d_reg_ratio)
+ self.optimizer_d = self.get_optimizer(optim_type, optim_params_d, lr, betas=betas)
+ self.optimizers.append(self.optimizer_d)
+
+ def feed_data(self, data):
+ self.real_img = data['gt'].to(self.device)
+
+ def make_noise(self, batch, num_noise):
+ if num_noise == 1:
+ noises = torch.randn(batch, self.num_style_feat, device=self.device)
+ else:
+ noises = torch.randn(num_noise, batch, self.num_style_feat, device=self.device).unbind(0)
+ return noises
+
+ def mixing_noise(self, batch, prob):
+ if random.random() < prob:
+ return self.make_noise(batch, 2)
+ else:
+ return [self.make_noise(batch, 1)]
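+
+ # Style mixing regularization: with probability `prob` two latent codes are drawn so
+ # the generator can cross over between styles at a random layer; otherwise a single
+ # latent code is used for all layers.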
+
+ def optimize_parameters(self, current_iter):
+ loss_dict = OrderedDict()
+
+ # optimize net_d
+ for p in self.net_d.parameters():
+ p.requires_grad = True
+ self.optimizer_d.zero_grad()
+
+ batch = self.real_img.size(0)
+ noise = self.mixing_noise(batch, self.mixing_prob)
+ fake_img, _ = self.net_g(noise)
+ fake_pred = self.net_d(fake_img.detach())
+
+ real_pred = self.net_d(self.real_img)
+ # wgan loss with softplus (logistic loss) for discriminator
+ l_d = self.cri_gan(real_pred, True, is_disc=True) + self.cri_gan(fake_pred, False, is_disc=True)
+ loss_dict['l_d'] = l_d
+ # In wgan, real_score should be positive and fake_score should be
+ # negative
+ loss_dict['real_score'] = real_pred.detach().mean()
+ loss_dict['fake_score'] = fake_pred.detach().mean()
+ l_d.backward()
+
+ if current_iter % self.net_d_reg_every == 0:
+ self.real_img.requires_grad = True
+ real_pred = self.net_d(self.real_img)
+ l_d_r1 = r1_penalty(real_pred, self.real_img)
+ l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0])
+ # TODO: why do we need to add 0 * real_pred, otherwise, a runtime
+ # error will arise: RuntimeError: Expected to have finished
+ # reduction in the prior iteration before starting a new one.
+ # This error indicates that your module has parameters that were
+ # not used in producing loss.
+ loss_dict['l_d_r1'] = l_d_r1.detach().mean()
+ l_d_r1.backward()
+
+ self.optimizer_d.step()
+
+ # optimize net_g
+ for p in self.net_d.parameters():
+ p.requires_grad = False
+ self.optimizer_g.zero_grad()
+
+ noise = self.mixing_noise(batch, self.mixing_prob)
+ fake_img, _ = self.net_g(noise)
+ fake_pred = self.net_d(fake_img)
+
+ # wgan loss with softplus (non-saturating loss) for generator
+ l_g = self.cri_gan(fake_pred, True, is_disc=False)
+ loss_dict['l_g'] = l_g
+ l_g.backward()
+
+ if current_iter % self.net_g_reg_every == 0:
+ path_batch_size = max(1, batch // self.opt['train']['path_batch_shrink'])
+ noise = self.mixing_noise(path_batch_size, self.mixing_prob)
+ fake_img, latents = self.net_g(noise, return_latents=True)
+ l_g_path, path_lengths, self.mean_path_length = g_path_regularize(fake_img, latents, self.mean_path_length)
+
+ l_g_path = (self.path_reg_weight * self.net_g_reg_every * l_g_path + 0 * fake_img[0, 0, 0, 0])
+ # TODO: why do we need to add 0 * fake_img[0, 0, 0, 0]
+ l_g_path.backward()
+ loss_dict['l_g_path'] = l_g_path.detach().mean()
+ loss_dict['path_length'] = path_lengths
+
+ self.optimizer_g.step()
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
+
+ # EMA
+ self.model_ema(decay=0.5**(32 / (10 * 1000)))
+
+ def test(self):
+ with torch.no_grad():
+ self.net_g_ema.eval()
+ self.output, _ = self.net_g_ema([self.fixed_sample])
+
+ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ if self.opt['rank'] == 0:
+ self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
+
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ assert dataloader is None, 'Validation dataloader should be None.'
+ self.test()
+ result = tensor2img(self.output, min_max=(-1, 1))
+ if self.opt['is_train']:
+ save_img_path = osp.join(self.opt['path']['visualization'], 'train', f'train_{current_iter}.png')
+ else:
+ save_img_path = osp.join(self.opt['path']['visualization'], 'test', f'test_{self.opt["name"]}.png')
+ imwrite(result, save_img_path)
+ # add sample images to tb_logger
+ result = (result / 255.).astype(np.float32)
+ result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
+ if tb_logger is not None:
+ tb_logger.add_image('samples', result, global_step=current_iter, dataformats='HWC')
+
+ def save(self, epoch, current_iter):
+ self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
+ self.save_network(self.net_d, 'net_d', current_iter)
+ self.save_training_state(epoch, current_iter)
diff --git a/StableSR/basicsr/models/swinir_model.py b/StableSR/basicsr/models/swinir_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac182f23b4a300aff14b2b45fcdca8c00da90c1
--- /dev/null
+++ b/StableSR/basicsr/models/swinir_model.py
@@ -0,0 +1,33 @@
+import torch
+from torch.nn import functional as F
+
+from basicsr.utils.registry import MODEL_REGISTRY
+from .sr_model import SRModel
+
+
+@MODEL_REGISTRY.register()
+class SwinIRModel(SRModel):
+
+ def test(self):
+ # pad to multiplication of window_size
+ window_size = self.opt['network_g']['window_size']
+ scale = self.opt.get('scale', 1)
+ mod_pad_h, mod_pad_w = 0, 0
+ _, _, h, w = self.lq.size()
+ if h % window_size != 0:
+ mod_pad_h = window_size - h % window_size
+ if w % window_size != 0:
+ mod_pad_w = window_size - w % window_size
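+ # e.g. window_size=8 and h=201 gives mod_pad_h=7, so the reflect-padded height 208
+ # is a multiple of the window size expected by the Swin attention blocks.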
+ img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
+ if hasattr(self, 'net_g_ema'):
+ self.net_g_ema.eval()
+ with torch.no_grad():
+ self.output = self.net_g_ema(img)
+ else:
+ self.net_g.eval()
+ with torch.no_grad():
+ self.output = self.net_g(img)
+ self.net_g.train()
+
+ _, _, h, w = self.output.size()
+ self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale]
diff --git a/StableSR/basicsr/models/video_base_model.py b/StableSR/basicsr/models/video_base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f7993a15e585526135d1ede094f4dcff47f64db
--- /dev/null
+++ b/StableSR/basicsr/models/video_base_model.py
@@ -0,0 +1,160 @@
+import torch
+from collections import Counter
+from os import path as osp
+from torch import distributed as dist
+from tqdm import tqdm
+
+from basicsr.metrics import calculate_metric
+from basicsr.utils import get_root_logger, imwrite, tensor2img
+from basicsr.utils.dist_util import get_dist_info
+from basicsr.utils.registry import MODEL_REGISTRY
+from .sr_model import SRModel
+
+
+@MODEL_REGISTRY.register()
+class VideoBaseModel(SRModel):
+ """Base video SR model."""
+
+ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ dataset = dataloader.dataset
+ dataset_name = dataset.opt['name']
+ with_metrics = self.opt['val']['metrics'] is not None
+ # initialize self.metric_results
+ # It is a dict: {
+ # 'folder1': tensor (num_frame x len(metrics)),
+ # 'folder2': tensor (num_frame x len(metrics))
+ # }
+ if with_metrics:
+ if not hasattr(self, 'metric_results'): # only execute in the first run
+ self.metric_results = {}
+ num_frame_each_folder = Counter(dataset.data_info['folder'])
+ for folder, num_frame in num_frame_each_folder.items():
+ self.metric_results[folder] = torch.zeros(
+ num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
+ # initialize the best metric results
+ self._initialize_best_metric_results(dataset_name)
+ # zero self.metric_results
+ rank, world_size = get_dist_info()
+ if with_metrics:
+ for _, tensor in self.metric_results.items():
+ tensor.zero_()
+
+ metric_data = dict()
+ # record all frames (border and center frames)
+ if rank == 0:
+ pbar = tqdm(total=len(dataset), unit='frame')
+ for idx in range(rank, len(dataset), world_size):
+ val_data = dataset[idx]
+ val_data['lq'].unsqueeze_(0)
+ val_data['gt'].unsqueeze_(0)
+ folder = val_data['folder']
+ frame_idx, max_idx = val_data['idx'].split('/')
+ lq_path = val_data['lq_path']
+
+ self.feed_data(val_data)
+ self.test()
+ visuals = self.get_current_visuals()
+ result_img = tensor2img([visuals['result']])
+ metric_data['img'] = result_img
+ if 'gt' in visuals:
+ gt_img = tensor2img([visuals['gt']])
+ metric_data['img2'] = gt_img
+ del self.gt
+
+ # tentative fix for running out of GPU memory
+ del self.lq
+ del self.output
+ torch.cuda.empty_cache()
+
+ if save_img:
+ if self.opt['is_train']:
+ raise NotImplementedError('saving image is not supported during training.')
+ else:
+ if 'vimeo' in dataset_name.lower(): # vimeo90k dataset
+ split_result = lq_path.split('/')
+ img_name = f'{split_result[-3]}_{split_result[-2]}_{split_result[-1].split(".")[0]}'
+ else: # other datasets, e.g., REDS, Vid4
+ img_name = osp.splitext(osp.basename(lq_path))[0]
+
+ if self.opt['val']['suffix']:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
+ f'{img_name}_{self.opt["val"]["suffix"]}.png')
+ else:
+ save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
+ f'{img_name}_{self.opt["name"]}.png')
+ imwrite(result_img, save_img_path)
+
+ if with_metrics:
+ # calculate metrics
+ for metric_idx, opt_ in enumerate(self.opt['val']['metrics'].values()):
+ result = calculate_metric(metric_data, opt_)
+ self.metric_results[folder][int(frame_idx), metric_idx] += result
+
+ # progress bar
+ if rank == 0:
+ for _ in range(world_size):
+ pbar.update(1)
+ pbar.set_description(f'Test {folder}: {int(frame_idx) + world_size}/{max_idx}')
+ if rank == 0:
+ pbar.close()
+
+ if with_metrics:
+ if self.opt['dist']:
+ # collect data among GPUs
+ for _, tensor in self.metric_results.items():
+ dist.reduce(tensor, 0)
+ dist.barrier()
+ else:
+ pass # assume use one gpu in non-dist testing
+
+ if rank == 0:
+ self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
+
+ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ logger = get_root_logger()
+ logger.warning('nondist_validation is not implemented. Run dist_validation.')
+ self.dist_validation(dataloader, current_iter, tb_logger, save_img)
+
+ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
+ # ----------------- calculate the average values for each folder, and for each metric ----------------- #
+ # average all frames for each sub-folder
+ # metric_results_avg is a dict:{
+ # 'folder1': tensor (len(metrics)),
+ # 'folder2': tensor (len(metrics))
+ # }
+ metric_results_avg = {
+ folder: torch.mean(tensor, dim=0).cpu()
+ for (folder, tensor) in self.metric_results.items()
+ }
+ # total_avg_results is a dict: {
+ # 'metric1': float,
+ # 'metric2': float
+ # }
+ total_avg_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
+ for folder, tensor in metric_results_avg.items():
+ for idx, metric in enumerate(total_avg_results.keys()):
+ total_avg_results[metric] += metric_results_avg[folder][idx].item()
+ # average among folders
+ for metric in total_avg_results.keys():
+ total_avg_results[metric] /= len(metric_results_avg)
+ # update the best metric result
+ self._update_best_metric_result(dataset_name, metric, total_avg_results[metric], current_iter)
+
+ # ------------------------------------------ log the metric ------------------------------------------ #
+ log_str = f'Validation {dataset_name}\n'
+ for metric_idx, (metric, value) in enumerate(total_avg_results.items()):
+ log_str += f'\t # {metric}: {value:.4f}'
+ for folder, tensor in metric_results_avg.items():
+ log_str += f'\t # {folder}: {tensor[metric_idx].item():.4f}'
+ if hasattr(self, 'best_metric_results'):
+ log_str += (f'\n\t Best: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
+ f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
+ log_str += '\n'
+
+ logger = get_root_logger()
+ logger.info(log_str)
+ if tb_logger:
+ for metric_idx, (metric, value) in enumerate(total_avg_results.items()):
+ tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
+ for folder, tensor in metric_results_avg.items():
+ tb_logger.add_scalar(f'metrics/{metric}/{folder}', tensor[metric_idx].item(), current_iter)
diff --git a/StableSR/basicsr/models/video_gan_model.py b/StableSR/basicsr/models/video_gan_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2adcdeee59e494dd7d1c285919fac5c99cd9efb
--- /dev/null
+++ b/StableSR/basicsr/models/video_gan_model.py
@@ -0,0 +1,19 @@
+from basicsr.utils.registry import MODEL_REGISTRY
+from .srgan_model import SRGANModel
+from .video_base_model import VideoBaseModel
+
+
+@MODEL_REGISTRY.register()
+class VideoGANModel(SRGANModel, VideoBaseModel):
+ """Video GAN model.
+
+ Use multiple inheritance.
+ It will first use the functions of :class:`SRGANModel`:
+
+ - :func:`init_training_settings`
+ - :func:`setup_optimizers`
+ - :func:`optimize_parameters`
+ - :func:`save`
+
+ Then find functions in :class:`VideoBaseModel`.
+ """
diff --git a/StableSR/basicsr/models/video_recurrent_gan_model.py b/StableSR/basicsr/models/video_recurrent_gan_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..74cf81145c50ffafb220d22b51e56746dee5ba41
--- /dev/null
+++ b/StableSR/basicsr/models/video_recurrent_gan_model.py
@@ -0,0 +1,180 @@
+import torch
+from collections import OrderedDict
+
+from basicsr.archs import build_network
+from basicsr.losses import build_loss
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import MODEL_REGISTRY
+from .video_recurrent_model import VideoRecurrentModel
+
+
+@MODEL_REGISTRY.register()
+class VideoRecurrentGANModel(VideoRecurrentModel):
+
+ def init_training_settings(self):
+ train_opt = self.opt['train']
+
+ self.ema_decay = train_opt.get('ema_decay', 0)
+ if self.ema_decay > 0:
+ logger = get_root_logger()
+ logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
+ # build network net_g with Exponential Moving Average (EMA)
+ # net_g_ema only used for testing on one GPU and saving.
+ # There is no need to wrap with DistributedDataParallel
+ self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
+ # load pretrained model
+ load_path = self.opt['path'].get('pretrain_network_g', None)
+ if load_path is not None:
+ self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
+ else:
+ self.model_ema(0) # copy net_g weight
+ self.net_g_ema.eval()
+
+ # define network net_d
+ self.net_d = build_network(self.opt['network_d'])
+ self.net_d = self.model_to_device(self.net_d)
+ self.print_network(self.net_d)
+
+ # load pretrained models
+ load_path = self.opt['path'].get('pretrain_network_d', None)
+ if load_path is not None:
+ param_key = self.opt['path'].get('param_key_d', 'params')
+ self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)
+
+ self.net_g.train()
+ self.net_d.train()
+
+ # define losses
+ if train_opt.get('pixel_opt'):
+ self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
+ else:
+ self.cri_pix = None
+
+ if train_opt.get('perceptual_opt'):
+ self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
+ else:
+ self.cri_perceptual = None
+
+ if train_opt.get('gan_opt'):
+ self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
+
+ self.net_d_iters = train_opt.get('net_d_iters', 1)
+ self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
+
+ # set up optimizers and schedulers
+ self.setup_optimizers()
+ self.setup_schedulers()
+
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ if train_opt['fix_flow']:
+ normal_params = []
+ flow_params = []
+ for name, param in self.net_g.named_parameters():
+ if 'spynet' in name: # currently, fix_flow only works for spynet
+ flow_params.append(param)
+ else:
+ normal_params.append(param)
+
+ optim_params = [
+ { # add flow params first
+ 'params': flow_params,
+ 'lr': train_opt['lr_flow']
+ },
+ {
+ 'params': normal_params,
+ 'lr': train_opt['optim_g']['lr']
+ },
+ ]
+ else:
+ optim_params = self.net_g.parameters()
+
+ # optimizer g
+ optim_type = train_opt['optim_g'].pop('type')
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
+ self.optimizers.append(self.optimizer_g)
+ # optimizer d
+ optim_type = train_opt['optim_d'].pop('type')
+ self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
+ self.optimizers.append(self.optimizer_d)
+
+ def optimize_parameters(self, current_iter):
+ logger = get_root_logger()
+ # optimize net_g
+ for p in self.net_d.parameters():
+ p.requires_grad = False
+
+ if self.fix_flow_iter:
+ if current_iter == 1:
+ logger.info(f'Fix flow network and feature extractor for {self.fix_flow_iter} iters.')
+ for name, param in self.net_g.named_parameters():
+ if 'spynet' in name or 'edvr' in name:
+ param.requires_grad_(False)
+ elif current_iter == self.fix_flow_iter:
+ logger.warning('Train all the parameters.')
+ self.net_g.requires_grad_(True)
+
+ self.optimizer_g.zero_grad()
+ self.output = self.net_g(self.lq)
+
+ _, _, c, h, w = self.output.size()
+
+ l_g_total = 0
+ loss_dict = OrderedDict()
+ if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
+ # pixel loss
+ if self.cri_pix:
+ l_g_pix = self.cri_pix(self.output, self.gt)
+ l_g_total += l_g_pix
+ loss_dict['l_g_pix'] = l_g_pix
+ # perceptual loss
+ if self.cri_perceptual:
+ l_g_percep, l_g_style = self.cri_perceptual(self.output.view(-1, c, h, w), self.gt.view(-1, c, h, w))
+ if l_g_percep is not None:
+ l_g_total += l_g_percep
+ loss_dict['l_g_percep'] = l_g_percep
+ if l_g_style is not None:
+ l_g_total += l_g_style
+ loss_dict['l_g_style'] = l_g_style
+ # gan loss
+ fake_g_pred = self.net_d(self.output.view(-1, c, h, w))
+ l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
+ l_g_total += l_g_gan
+ loss_dict['l_g_gan'] = l_g_gan
+
+ l_g_total.backward()
+ self.optimizer_g.step()
+
+ # optimize net_d
+ for p in self.net_d.parameters():
+ p.requires_grad = True
+
+ self.optimizer_d.zero_grad()
+ # real
+ # reshape to (b*n, c, h, w)
+ real_d_pred = self.net_d(self.gt.view(-1, c, h, w))
+ l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
+ loss_dict['l_d_real'] = l_d_real
+ loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
+ l_d_real.backward()
+ # fake
+ # reshape to (b*n, c, h, w)
+ fake_d_pred = self.net_d(self.output.view(-1, c, h, w).detach())
+ l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
+ loss_dict['l_d_fake'] = l_d_fake
+ loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
+ l_d_fake.backward()
+ self.optimizer_d.step()
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
+
+ if self.ema_decay > 0:
+ self.model_ema(decay=self.ema_decay)
+
+ def save(self, epoch, current_iter):
+ if self.ema_decay > 0:
+ self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
+ else:
+ self.save_network(self.net_g, 'net_g', current_iter)
+ self.save_network(self.net_d, 'net_d', current_iter)
+ self.save_training_state(epoch, current_iter)
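+
+# Illustrative sketch (an assumption, not part of the upstream file): net_d is
+# an image discriminator, so the 5-D video tensors of shape (b, n, c, h, w)
+# are flattened to (b * n, c, h, w) before scoring:
+#
+#   >>> output = torch.randn(2, 7, 3, 64, 64)   # b=2 clips of n=7 frames
+#   >>> output.view(-1, 3, 64, 64).shape
+#   torch.Size([14, 3, 64, 64])
+#
+# The generator losses are accumulated only every net_d_iters iterations
+# (after net_d_init_iters warm-up steps) with net_d frozen; the discriminator
+# is then updated on the real and fake batches separately.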
diff --git a/StableSR/basicsr/models/video_recurrent_model.py b/StableSR/basicsr/models/video_recurrent_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..796ee57d5aeb84e81fe8dc769facc8339798cc3e
--- /dev/null
+++ b/StableSR/basicsr/models/video_recurrent_model.py
@@ -0,0 +1,197 @@
+import torch
+from collections import Counter
+from os import path as osp
+from torch import distributed as dist
+from tqdm import tqdm
+
+from basicsr.metrics import calculate_metric
+from basicsr.utils import get_root_logger, imwrite, tensor2img
+from basicsr.utils.dist_util import get_dist_info
+from basicsr.utils.registry import MODEL_REGISTRY
+from .video_base_model import VideoBaseModel
+
+
+@MODEL_REGISTRY.register()
+class VideoRecurrentModel(VideoBaseModel):
+
+ def __init__(self, opt):
+ super(VideoRecurrentModel, self).__init__(opt)
+ if self.is_train:
+ self.fix_flow_iter = opt['train'].get('fix_flow')
+
+ def setup_optimizers(self):
+ train_opt = self.opt['train']
+ flow_lr_mul = train_opt.get('flow_lr_mul', 1)
+ logger = get_root_logger()
+ logger.info(f'Multiply the learning rate of the flow network by {flow_lr_mul}.')
+ if flow_lr_mul == 1:
+ optim_params = self.net_g.parameters()
+ else: # separate flow params and normal params for different lr
+ normal_params = []
+ flow_params = []
+ for name, param in self.net_g.named_parameters():
+ if 'spynet' in name:
+ flow_params.append(param)
+ else:
+ normal_params.append(param)
+ optim_params = [
+ { # add normal params first
+ 'params': normal_params,
+ 'lr': train_opt['optim_g']['lr']
+ },
+ {
+ 'params': flow_params,
+ 'lr': train_opt['optim_g']['lr'] * flow_lr_mul
+ },
+ ]
+
+ optim_type = train_opt['optim_g'].pop('type')
+ self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
+ self.optimizers.append(self.optimizer_g)
+
+ def optimize_parameters(self, current_iter):
+ if self.fix_flow_iter:
+ logger = get_root_logger()
+ if current_iter == 1:
+ logger.info(f'Fix flow network and feature extractor for {self.fix_flow_iter} iters.')
+ for name, param in self.net_g.named_parameters():
+ if 'spynet' in name or 'edvr' in name:
+ param.requires_grad_(False)
+ elif current_iter == self.fix_flow_iter:
+ logger.warning('Train all the parameters.')
+ self.net_g.requires_grad_(True)
+
+ super(VideoRecurrentModel, self).optimize_parameters(current_iter)
+
+ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
+ dataset = dataloader.dataset
+ dataset_name = dataset.opt['name']
+ with_metrics = self.opt['val']['metrics'] is not None
+ # initialize self.metric_results
+ # It is a dict: {
+ # 'folder1': tensor (num_frame x len(metrics)),
+ # 'folder2': tensor (num_frame x len(metrics))
+ # }
+ if with_metrics:
+ if not hasattr(self, 'metric_results'): # only execute in the first run
+ self.metric_results = {}
+ num_frame_each_folder = Counter(dataset.data_info['folder'])
+ for folder, num_frame in num_frame_each_folder.items():
+ self.metric_results[folder] = torch.zeros(
+ num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
+ # initialize the best metric results
+ self._initialize_best_metric_results(dataset_name)
+ # zero self.metric_results
+ rank, world_size = get_dist_info()
+ if with_metrics:
+ for _, tensor in self.metric_results.items():
+ tensor.zero_()
+
+ metric_data = dict()
+ num_folders = len(dataset)
+ num_pad = (world_size - (num_folders % world_size)) % world_size
+ if rank == 0:
+ pbar = tqdm(total=len(dataset), unit='folder')
+ # Evaluate (num_folders + num_pad) times; only the first num_folders results are recorded.
+ # (The padding avoids a deadlock when num_folders is not divisible by world_size.)
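+ # e.g. (illustrative): with 3 folders and world_size=2, num_pad=1; rank 0
+ # evaluates folders 0 and 2, rank 1 evaluates folder 1 and re-evaluates
+ # folder 2 (i=3 -> idx=2), whose duplicate result is simply discarded.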
+ for i in range(rank, num_folders + num_pad, world_size):
+ idx = min(i, num_folders - 1)
+ val_data = dataset[idx]
+ folder = val_data['folder']
+
+ # compute outputs
+ val_data['lq'].unsqueeze_(0)
+ val_data['gt'].unsqueeze_(0)
+ self.feed_data(val_data)
+ val_data['lq'].squeeze_(0)
+ val_data['gt'].squeeze_(0)
+
+ self.test()
+ visuals = self.get_current_visuals()
+
+ # tentative for out of GPU memory
+ del self.lq
+ del self.output
+ if 'gt' in visuals:
+ del self.gt
+ torch.cuda.empty_cache()
+
+ if self.center_frame_only:
+ visuals['result'] = visuals['result'].unsqueeze(1)
+ if 'gt' in visuals:
+ visuals['gt'] = visuals['gt'].unsqueeze(1)
+
+ # evaluate
+ if i < num_folders:
+ for idx in range(visuals['result'].size(1)):
+ result = visuals['result'][0, idx, :, :, :]
+ result_img = tensor2img([result]) # uint8, bgr
+ metric_data['img'] = result_img
+ if 'gt' in visuals:
+ gt = visuals['gt'][0, idx, :, :, :]
+ gt_img = tensor2img([gt]) # uint8, bgr
+ metric_data['img2'] = gt_img
+
+ if save_img:
+ if self.opt['is_train']:
+ raise NotImplementedError('saving image is not supported during training.')
+ else:
+ if self.center_frame_only: # vimeo-90k
+ clip_ = val_data['lq_path'].split('/')[-3]
+ seq_ = val_data['lq_path'].split('/')[-2]
+ name_ = f'{clip_}_{seq_}'
+ img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
+ f"{name_}_{self.opt['name']}.png")
+ else: # others
+ img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
+ f"{idx:08d}_{self.opt['name']}.png")
+ # the frame-index file name above follows the REDS dataset convention
+ imwrite(result_img, img_path)
+
+ # calculate metrics
+ if with_metrics:
+ for metric_idx, opt_ in enumerate(self.opt['val']['metrics'].values()):
+ result = calculate_metric(metric_data, opt_)
+ self.metric_results[folder][idx, metric_idx] += result
+
+ # progress bar
+ if rank == 0:
+ for _ in range(world_size):
+ pbar.update(1)
+ pbar.set_description(f'Folder: {folder}')
+
+ if rank == 0:
+ pbar.close()
+
+ if with_metrics:
+ if self.opt['dist']:
+ # collect data among GPUs
+ for _, tensor in self.metric_results.items():
+ dist.reduce(tensor, 0)
+ dist.barrier()
+
+ if rank == 0:
+ self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
+
+ def test(self):
+ n = self.lq.size(1)
+ self.net_g.eval()
+
+ flip_seq = self.opt['val'].get('flip_seq', False)
+ self.center_frame_only = self.opt['val'].get('center_frame_only', False)
+
+ if flip_seq:
+ self.lq = torch.cat([self.lq, self.lq.flip(1)], dim=1)
+
+ with torch.no_grad():
+ self.output = self.net_g(self.lq)
+
+ if flip_seq:
+ output_1 = self.output[:, :n, :, :, :]
+ output_2 = self.output[:, n:, :, :, :].flip(1)
+ self.output = 0.5 * (output_1 + output_2)
+
+ if self.center_frame_only:
+ self.output = self.output[:, n // 2, :, :, :]
+
+ self.net_g.train()
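+
+# Illustrative note (an assumption, not part of the upstream file): flip_seq
+# is a temporal self-ensemble. The clip is concatenated with its time-reversed
+# copy, both halves are restored in a single forward pass, and the reversed
+# half is flipped back and averaged with the forward half:
+#
+#   >>> lq = torch.randn(1, 5, 3, 64, 64)            # (b, n, c, h, w)
+#   >>> lq_ens = torch.cat([lq, lq.flip(1)], dim=1)  # (1, 10, 3, 64, 64)
+#   # output[:, :5] and output[:, 5:].flip(1) are then averaged 50/50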
diff --git a/StableSR/basicsr/ops/__init__.py b/StableSR/basicsr/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/StableSR/basicsr/ops/dcn/__init__.py b/StableSR/basicsr/ops/dcn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..32e3592f896d61b4127e09d0476381b9d55e32ff
--- /dev/null
+++ b/StableSR/basicsr/ops/dcn/__init__.py
@@ -0,0 +1,7 @@
+from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
+ modulated_deform_conv)
+
+__all__ = [
+ 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
+ 'modulated_deform_conv'
+]
diff --git a/StableSR/basicsr/ops/dcn/deform_conv.py b/StableSR/basicsr/ops/dcn/deform_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..6268ca825d59ef4a30d4d2156c4438cbbe9b3c1e
--- /dev/null
+++ b/StableSR/basicsr/ops/dcn/deform_conv.py
@@ -0,0 +1,379 @@
+import math
+import os
+import torch
+from torch import nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn import functional as F
+from torch.nn.modules.utils import _pair, _single
+
+BASICSR_JIT = os.getenv('BASICSR_JIT')
+if BASICSR_JIT == 'True':
+ from torch.utils.cpp_extension import load
+ module_path = os.path.dirname(__file__)
+ deform_conv_ext = load(
+ 'deform_conv',
+ sources=[
+ os.path.join(module_path, 'src', 'deform_conv_ext.cpp'),
+ os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'),
+ os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'),
+ ],
+ )
+else:
+ try:
+ from . import deform_conv_ext
+ except ImportError:
+ pass
+ # avoid annoying print output
+ # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n '
+ # '1. compile with BASICSR_EXT=True. or\n '
+ # '2. set BASICSR_JIT=True during running')
+
+
+class DeformConvFunction(Function):
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ offset,
+ weight,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ im2col_step=64):
+ if input is not None and input.dim() != 4:
+ raise ValueError(f'Expected 4D tensor as input, got {input.dim()}D tensor instead.')
+ ctx.stride = _pair(stride)
+ ctx.padding = _pair(padding)
+ ctx.dilation = _pair(dilation)
+ ctx.groups = groups
+ ctx.deformable_groups = deformable_groups
+ ctx.im2col_step = im2col_step
+
+ ctx.save_for_backward(input, offset, weight)
+
+ output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))
+
+ ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
+
+ if not input.is_cuda:
+ raise NotImplementedError
+ else:
+ cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+ assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
+ deform_conv_ext.deform_conv_forward(input, weight,
+ offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
+ weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
+ ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
+ ctx.deformable_groups, cur_im2col_step)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ input, offset, weight = ctx.saved_tensors
+
+ grad_input = grad_offset = grad_weight = None
+
+ if not grad_output.is_cuda:
+ raise NotImplementedError
+ else:
+ cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+ assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+ grad_input = torch.zeros_like(input)
+ grad_offset = torch.zeros_like(offset)
+ deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input,
+ grad_offset, weight, ctx.bufs_[0], weight.size(3),
+ weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
+ ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
+ ctx.deformable_groups, cur_im2col_step)
+
+ if ctx.needs_input_grad[2]:
+ grad_weight = torch.zeros_like(weight)
+ deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight,
+ ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
+ weight.size(2), ctx.stride[1], ctx.stride[0],
+ ctx.padding[1], ctx.padding[0], ctx.dilation[1],
+ ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
+ cur_im2col_step)
+
+ return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
+
+ @staticmethod
+ def _output_size(input, weight, padding, dilation, stride):
+ channels = weight.size(0)
+ output_size = (input.size(0), channels)
+ for d in range(input.dim() - 2):
+ in_size = input.size(d + 2)
+ pad = padding[d]
+ kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
+ stride_ = stride[d]
+ output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
+ if not all(map(lambda s: s > 0, output_size)):
+ raise ValueError(f'convolution input is too small (output would be {"x".join(map(str, output_size))})')
+ return output_size
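+ # e.g. (illustrative): for a (1, 64, 64, 64) input with a 3x3 kernel,
+ # padding=1, dilation=1 and stride=1, each spatial dim is
+ # (64 + 2*1 - 3) // 1 + 1 = 64, giving (1, out_channels, 64, 64).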
+
+
+class ModulatedDeformConvFunction(Function):
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ offset,
+ mask,
+ weight,
+ bias=None,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1):
+ ctx.stride = stride
+ ctx.padding = padding
+ ctx.dilation = dilation
+ ctx.groups = groups
+ ctx.deformable_groups = deformable_groups
+ ctx.with_bias = bias is not None
+ if not ctx.with_bias:
+ bias = input.new_empty(1) # fake tensor
+ if not input.is_cuda:
+ raise NotImplementedError
+ if weight.requires_grad or mask.requires_grad or offset.requires_grad or input.requires_grad:
+ ctx.save_for_backward(input, offset, mask, weight, bias)
+ output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
+ ctx._bufs = [input.new_empty(0), input.new_empty(0)]
+ deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output,
+ ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
+ ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
+ ctx.groups, ctx.deformable_groups, ctx.with_bias)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ if not grad_output.is_cuda:
+ raise NotImplementedError
+ input, offset, mask, weight, bias = ctx.saved_tensors
+ grad_input = torch.zeros_like(input)
+ grad_offset = torch.zeros_like(offset)
+ grad_mask = torch.zeros_like(mask)
+ grad_weight = torch.zeros_like(weight)
+ grad_bias = torch.zeros_like(bias)
+ deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
+ grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
+ grad_output, weight.shape[2], weight.shape[3], ctx.stride,
+ ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
+ ctx.groups, ctx.deformable_groups, ctx.with_bias)
+ if not ctx.with_bias:
+ grad_bias = None
+
+ return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None)
+
+ @staticmethod
+ def _infer_shape(ctx, input, weight):
+ n = input.size(0)
+ channels_out = weight.size(0)
+ height, width = input.shape[2:4]
+ kernel_h, kernel_w = weight.shape[2:4]
+ height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
+ width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
+ return n, channels_out, height_out, width_out
+
+
+deform_conv = DeformConvFunction.apply
+modulated_deform_conv = ModulatedDeformConvFunction.apply
+
+
+class DeformConv(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ bias=False):
+ super(DeformConv, self).__init__()
+
+ assert not bias
+ assert in_channels % groups == 0, f'in_channels {in_channels} is not divisible by groups {groups}'
+ assert out_channels % groups == 0, f'out_channels {out_channels} is not divisible by groups {groups}'
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = _pair(stride)
+ self.padding = _pair(padding)
+ self.dilation = _pair(dilation)
+ self.groups = groups
+ self.deformable_groups = deformable_groups
+ # enable compatibility with nn.Conv2d
+ self.transposed = False
+ self.output_padding = _single(0)
+
+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ n = self.in_channels
+ for k in self.kernel_size:
+ n *= k
+ stdv = 1. / math.sqrt(n)
+ self.weight.data.uniform_(-stdv, stdv)
+
+ def forward(self, x, offset):
+ # Pad the input when it is smaller than the kernel,
+ # to avoid an assert error in deform_conv_cuda.cpp:128
+ input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1])
+ if input_pad:
+ pad_h = max(self.kernel_size[0] - x.size(2), 0)
+ pad_w = max(self.kernel_size[1] - x.size(3), 0)
+ x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
+ offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
+ out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
+ self.deformable_groups)
+ if input_pad:
+ out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous()
+ return out
+
+
+class DeformConvPack(DeformConv):
+ """A Deformable Conv Encapsulation that acts as normal Conv layers.
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
+ stride (int or tuple[int]): Same as nn.Conv2d.
+ padding (int or tuple[int]): Same as nn.Conv2d.
+ dilation (int or tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ bias (bool or str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+ False.
+ """
+
+ _version = 2
+
+ def __init__(self, *args, **kwargs):
+ super(DeformConvPack, self).__init__(*args, **kwargs)
+
+ self.conv_offset = nn.Conv2d(
+ self.in_channels,
+ self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
+ kernel_size=self.kernel_size,
+ stride=_pair(self.stride),
+ padding=_pair(self.padding),
+ dilation=_pair(self.dilation),
+ bias=True)
+ self.init_offset()
+
+ def init_offset(self):
+ self.conv_offset.weight.data.zero_()
+ self.conv_offset.bias.data.zero_()
+
+ def forward(self, x):
+ offset = self.conv_offset(x)
+ return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
+ self.deformable_groups)
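+
+# Illustrative usage sketch (an assumption, not part of the upstream file):
+# DeformConvPack predicts its own offsets, so it can act as a drop-in conv
+# (the compiled CUDA extension, or BASICSR_JIT=True, is required to run it):
+#
+#   >>> conv = DeformConvPack(64, 64, kernel_size=3, padding=1).cuda()
+#   >>> feat = torch.randn(2, 64, 32, 32, device='cuda')
+#   >>> out = conv(feat)  # offsets come from the internal conv_offset layer
+#   >>> out.shape         # torch.Size([2, 64, 32, 32])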
+
+
+class ModulatedDeformConv(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ bias=True):
+ super(ModulatedDeformConv, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.groups = groups
+ self.deformable_groups = deformable_groups
+ self.with_bias = bias
+ # enable compatibility with nn.Conv2d
+ self.transposed = False
+ self.output_padding = _single(0)
+
+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
+ else:
+ self.register_parameter('bias', None)
+ self.init_weights()
+
+ def init_weights(self):
+ n = self.in_channels
+ for k in self.kernel_size:
+ n *= k
+ stdv = 1. / math.sqrt(n)
+ self.weight.data.uniform_(-stdv, stdv)
+ if self.bias is not None:
+ self.bias.data.zero_()
+
+ def forward(self, x, offset, mask):
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
+ self.groups, self.deformable_groups)
+
+
+class ModulatedDeformConvPack(ModulatedDeformConv):
+ """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
+ stride (int or tuple[int]): Same as nn.Conv2d.
+ padding (int or tuple[int]): Same as nn.Conv2d.
+ dilation (int or tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ bias (bool or str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+ False.
+ """
+
+ _version = 2
+
+ def __init__(self, *args, **kwargs):
+ super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
+
+ self.conv_offset = nn.Conv2d(
+ self.in_channels,
+ self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
+ kernel_size=self.kernel_size,
+ stride=_pair(self.stride),
+ padding=_pair(self.padding),
+ dilation=_pair(self.dilation),
+ bias=True)
+ self.init_weights()
+
+ def init_weights(self):
+ super(ModulatedDeformConvPack, self).init_weights()
+ if hasattr(self, 'conv_offset'):
+ self.conv_offset.weight.data.zero_()
+ self.conv_offset.bias.data.zero_()
+
+ def forward(self, x):
+ out = self.conv_offset(x)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+ offset = torch.cat((o1, o2), dim=1)
+ mask = torch.sigmoid(mask)
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
+ self.groups, self.deformable_groups)
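+
+# Illustrative note (an assumption, not part of the upstream file): this is
+# the modulated (DCNv2-style) variant. conv_offset predicts
+# deformable_groups * 3 * k * k channels that are chunked into x/y offsets and
+# a sigmoid-gated modulation mask; e.g. for k=3 and deformable_groups=1, the
+# 27 channels split into 18 offset channels and 9 mask channels.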
diff --git a/StableSR/basicsr/ops/fused_act/__init__.py b/StableSR/basicsr/ops/fused_act/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..241dc0754fae7d88dbbd9a02e665ca30a73c7422
--- /dev/null
+++ b/StableSR/basicsr/ops/fused_act/__init__.py
@@ -0,0 +1,3 @@
+from .fused_act import FusedLeakyReLU, fused_leaky_relu
+
+__all__ = ['FusedLeakyReLU', 'fused_leaky_relu']
diff --git a/StableSR/basicsr/ops/fused_act/fused_act.py b/StableSR/basicsr/ops/fused_act/fused_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..88edc445484b71119dc22a258e83aef49ce39b07
--- /dev/null
+++ b/StableSR/basicsr/ops/fused_act/fused_act.py
@@ -0,0 +1,95 @@
+# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
+
+import os
+import torch
+from torch import nn
+from torch.autograd import Function
+
+BASICSR_JIT = os.getenv('BASICSR_JIT')
+if BASICSR_JIT == 'True':
+ from torch.utils.cpp_extension import load
+ module_path = os.path.dirname(__file__)
+ fused_act_ext = load(
+ 'fused',
+ sources=[
+ os.path.join(module_path, 'src', 'fused_bias_act.cpp'),
+ os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'),
+ ],
+ )
+else:
+ try:
+ from . import fused_act_ext
+ except ImportError:
+ pass
+ # avoid annoying print output
+ # print(f'Cannot import fused_act_ext. Error: {error}. You may need to: \n '
+ # '1. compile with BASICSR_EXT=True. or\n '
+ # '2. set BASICSR_JIT=True during running')
+
+
+class FusedLeakyReLUFunctionBackward(Function):
+
+ @staticmethod
+ def forward(ctx, grad_output, out, negative_slope, scale):
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+
+ empty = grad_output.new_empty(0)
+
+ grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale)
+
+ dim = [0]
+
+ if grad_input.ndim > 2:
+ dim += list(range(2, grad_input.ndim))
+
+ grad_bias = grad_input.sum(dim).detach()
+
+ return grad_input, grad_bias
+
+ @staticmethod
+ def backward(ctx, gradgrad_input, gradgrad_bias):
+ out, = ctx.saved_tensors
+ gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope,
+ ctx.scale)
+
+ return gradgrad_out, None, None, None
+
+
+class FusedLeakyReLUFunction(Function):
+
+ @staticmethod
+ def forward(ctx, input, bias, negative_slope, scale):
+ empty = input.new_empty(0)
+ out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ out, = ctx.saved_tensors
+
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale)
+
+ return grad_input, grad_bias, None, None
+
+
+class FusedLeakyReLU(nn.Module):
+
+ def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
+ super().__init__()
+
+ self.bias = nn.Parameter(torch.zeros(channel))
+ self.negative_slope = negative_slope
+ self.scale = scale
+
+ def forward(self, input):
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
+ return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
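+
+# Illustrative note (an assumption, not part of the upstream file): the fused
+# CUDA op computes roughly scale * leaky_relu(input + bias), with the bias
+# broadcast over the channel dimension. A rough pure-PyTorch reference
+# (hypothetical helper, useful when the extension is unavailable):
+#
+#   import torch.nn.functional as F
+#   def fused_leaky_relu_ref(x, bias, negative_slope=0.2, scale=2**0.5):
+#       shape = [1, -1] + [1] * (x.ndim - 2)
+#       return F.leaky_relu(x + bias.view(*shape), negative_slope) * scale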
diff --git a/StableSR/basicsr/ops/upfirdn2d/__init__.py b/StableSR/basicsr/ops/upfirdn2d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..397e85bea063e97fc4c12ad4d3e15669b69290bd
--- /dev/null
+++ b/StableSR/basicsr/ops/upfirdn2d/__init__.py
@@ -0,0 +1,3 @@
+from .upfirdn2d import upfirdn2d
+
+__all__ = ['upfirdn2d']
diff --git a/StableSR/basicsr/ops/upfirdn2d/upfirdn2d.py b/StableSR/basicsr/ops/upfirdn2d/upfirdn2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6122d59aa32fd52e956bd36200ba79af4a17b17
--- /dev/null
+++ b/StableSR/basicsr/ops/upfirdn2d/upfirdn2d.py
@@ -0,0 +1,192 @@
+# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501
+
+import os
+import torch
+from torch.autograd import Function
+from torch.nn import functional as F
+
+BASICSR_JIT = os.getenv('BASICSR_JIT')
+if BASICSR_JIT == 'True':
+ from torch.utils.cpp_extension import load
+ module_path = os.path.dirname(__file__)
+ upfirdn2d_ext = load(
+ 'upfirdn2d',
+ sources=[
+ os.path.join(module_path, 'src', 'upfirdn2d.cpp'),
+ os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'),
+ ],
+ )
+else:
+ try:
+ from . import upfirdn2d_ext
+ except ImportError:
+ pass
+ # avoid annoying print output
+ # print(f'Cannot import upfirdn2d_ext. Error: {error}. You may need to: \n '
+ # '1. compile with BASICSR_EXT=True. or\n '
+ # '2. set BASICSR_JIT=True during running')
+
+
+class UpFirDn2dBackward(Function):
+
+ @staticmethod
+ def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size):
+
+ up_x, up_y = up
+ down_x, down_y = down
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
+
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
+
+ grad_input = upfirdn2d_ext.upfirdn2d(
+ grad_output,
+ grad_kernel,
+ down_x,
+ down_y,
+ up_x,
+ up_y,
+ g_pad_x0,
+ g_pad_x1,
+ g_pad_y0,
+ g_pad_y1,
+ )
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
+
+ ctx.save_for_backward(kernel)
+
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+ ctx.up_x = up_x
+ ctx.up_y = up_y
+ ctx.down_x = down_x
+ ctx.down_y = down_y
+ ctx.pad_x0 = pad_x0
+ ctx.pad_x1 = pad_x1
+ ctx.pad_y0 = pad_y0
+ ctx.pad_y1 = pad_y1
+ ctx.in_size = in_size
+ ctx.out_size = out_size
+
+ return grad_input
+
+ @staticmethod
+ def backward(ctx, gradgrad_input):
+ kernel, = ctx.saved_tensors
+
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
+
+ gradgrad_out = upfirdn2d_ext.upfirdn2d(
+ gradgrad_input,
+ kernel,
+ ctx.up_x,
+ ctx.up_y,
+ ctx.down_x,
+ ctx.down_y,
+ ctx.pad_x0,
+ ctx.pad_x1,
+ ctx.pad_y0,
+ ctx.pad_y1,
+ )
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
+ # ctx.out_size[1], ctx.in_size[3])
+ gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1])
+
+ return gradgrad_out, None, None, None, None, None, None, None, None
+
+
+class UpFirDn2d(Function):
+
+ @staticmethod
+ def forward(ctx, input, kernel, up, down, pad):
+ up_x, up_y = up
+ down_x, down_y = down
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+ kernel_h, kernel_w = kernel.shape
+ _, channel, in_h, in_w = input.shape
+ ctx.in_size = input.shape
+
+ input = input.reshape(-1, in_h, in_w, 1)
+
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+ ctx.out_size = (out_h, out_w)
+
+ ctx.up = (up_x, up_y)
+ ctx.down = (down_x, down_y)
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
+
+ g_pad_x0 = kernel_w - pad_x0 - 1
+ g_pad_y0 = kernel_h - pad_y0 - 1
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
+
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
+
+ out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1)
+ # out = out.view(major, out_h, out_w, minor)
+ out = out.view(-1, channel, out_h, out_w)
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ kernel, grad_kernel = ctx.saved_tensors
+
+ grad_input = UpFirDn2dBackward.apply(
+ grad_output,
+ kernel,
+ grad_kernel,
+ ctx.up,
+ ctx.down,
+ ctx.pad,
+ ctx.g_pad,
+ ctx.in_size,
+ ctx.out_size,
+ )
+
+ return grad_input, None, None, None, None
+
+
+def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+ if input.device.type == 'cpu':
+ out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
+ else:
+ out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]))
+
+ return out
+
+
+def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
+ _, channel, in_h, in_w = input.shape
+ input = input.reshape(-1, in_h, in_w, 1)
+
+ _, in_h, in_w, minor = input.shape
+ kernel_h, kernel_w = kernel.shape
+
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+ out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
+ out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ]
+
+ out = out.permute(0, 3, 1, 2)
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+ out = F.conv2d(out, w)
+ out = out.reshape(
+ -1,
+ minor,
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+ )
+ out = out.permute(0, 2, 3, 1)
+ out = out[:, ::down_y, ::down_x, :]
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+
+ return out.view(-1, channel, out_h, out_w)
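+
+# Illustrative usage sketch (an assumption, not part of the upstream file):
+# upfirdn2d upsamples by `up`, convolves with the FIR `kernel`, then
+# downsamples by `down`. A StyleGAN2-style 2x upsample with a blur kernel:
+#
+#   >>> k = torch.tensor([1., 3., 3., 1.])
+#   >>> kernel = k[:, None] * k[None, :]
+#   >>> kernel = kernel / kernel.sum()
+#   >>> x = torch.randn(1, 3, 16, 16)
+#   >>> upfirdn2d(x, kernel, up=2, down=1, pad=(2, 1)).shape
+#   torch.Size([1, 3, 32, 32])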
diff --git a/StableSR/basicsr/test.py b/StableSR/basicsr/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..53cb3b7aa860c90518e15ba76e1a55fdf404bcc2
--- /dev/null
+++ b/StableSR/basicsr/test.py
@@ -0,0 +1,45 @@
+import logging
+import torch
+from os import path as osp
+
+from basicsr.data import build_dataloader, build_dataset
+from basicsr.models import build_model
+from basicsr.utils import get_env_info, get_root_logger, get_time_str, make_exp_dirs
+from basicsr.utils.options import dict2str, parse_options
+
+
+def test_pipeline(root_path):
+ # parse options, set distributed setting, set random seed
+ opt, _ = parse_options(root_path, is_train=False)
+
+ torch.backends.cudnn.benchmark = True
+ # torch.backends.cudnn.deterministic = True
+
+ # mkdir and initialize loggers
+ make_exp_dirs(opt)
+ log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log")
+ logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
+ logger.info(get_env_info())
+ logger.info(dict2str(opt))
+
+ # create test dataset and dataloader
+ test_loaders = []
+ for _, dataset_opt in sorted(opt['datasets'].items()):
+ test_set = build_dataset(dataset_opt)
+ test_loader = build_dataloader(
+ test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
+ logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}")
+ test_loaders.append(test_loader)
+
+ # create model
+ model = build_model(opt)
+
+ for test_loader in test_loaders:
+ test_set_name = test_loader.dataset.opt['name']
+ logger.info(f'Testing {test_set_name}...')
+ model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img'])
+
+
+if __name__ == '__main__':
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
+ test_pipeline(root_path)
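+
+# Illustrative usage (an assumption; the option file path is hypothetical):
+#   python basicsr/test.py -opt options/test/test_example.yml
+# parse_options reads the YAML config; every dataset listed under `datasets`
+# is then evaluated with the model built from the config.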
diff --git a/StableSR/basicsr/train.py b/StableSR/basicsr/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..e02d98fe07f8c2924dda5b49f95adfa21990fa91
--- /dev/null
+++ b/StableSR/basicsr/train.py
@@ -0,0 +1,215 @@
+import datetime
+import logging
+import math
+import time
+import torch
+from os import path as osp
+
+from basicsr.data import build_dataloader, build_dataset
+from basicsr.data.data_sampler import EnlargedSampler
+from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
+from basicsr.models import build_model
+from basicsr.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str,
+ init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir)
+from basicsr.utils.options import copy_opt_file, dict2str, parse_options
+
+
+def init_tb_loggers(opt):
+ # initialize wandb logger before tensorboard logger to allow proper sync
+ if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project')
+ is not None) and ('debug' not in opt['name']):
+ assert opt['logger'].get('use_tb_logger') is True, 'use_tb_logger must be True when wandb logging is enabled'
+ init_wandb_logger(opt)
+ tb_logger = None
+ if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']:
+ tb_logger = init_tb_logger(log_dir=osp.join(opt['root_path'], 'tb_logger', opt['name']))
+ return tb_logger
+
+
+def create_train_val_dataloader(opt, logger):
+ # create train and val dataloaders
+ train_loader, val_loaders = None, []
+ for phase, dataset_opt in opt['datasets'].items():
+ if phase == 'train':
+ dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
+ train_set = build_dataset(dataset_opt)
+ train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio)
+ train_loader = build_dataloader(
+ train_set,
+ dataset_opt,
+ num_gpu=opt['num_gpu'],
+ dist=opt['dist'],
+ sampler=train_sampler,
+ seed=opt['manual_seed'])
+
+ num_iter_per_epoch = math.ceil(
+ len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
+ total_iters = int(opt['train']['total_iter'])
+ total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
+ logger.info('Training statistics:'
+ f'\n\tNumber of train images: {len(train_set)}'
+ f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
+ f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
+ f'\n\tWorld size (gpu number): {opt["world_size"]}'
+ f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
+ f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
+ elif phase.split('_')[0] == 'val':
+ val_set = build_dataset(dataset_opt)
+ val_loader = build_dataloader(
+ val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
+ logger.info(f'Number of val images/folders in {dataset_opt["name"]}: {len(val_set)}')
+ val_loaders.append(val_loader)
+ else:
+ raise ValueError(f'Dataset phase {phase} is not recognized.')
+
+ return train_loader, train_sampler, val_loaders, total_epochs, total_iters
+
+
+def load_resume_state(opt):
+ resume_state_path = None
+ if opt['auto_resume']:
+ state_path = osp.join('experiments', opt['name'], 'training_states')
+ if osp.isdir(state_path):
+ states = list(scandir(state_path, suffix='state', recursive=False, full_path=False))
+ if len(states) != 0:
+ states = [float(v.split('.state')[0]) for v in states]
+ resume_state_path = osp.join(state_path, f'{max(states):.0f}.state')
+ opt['path']['resume_state'] = resume_state_path
+ else:
+ if opt['path'].get('resume_state'):
+ resume_state_path = opt['path']['resume_state']
+
+ if resume_state_path is None:
+ resume_state = None
+ else:
+ device_id = torch.cuda.current_device()
+ resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id))
+ check_resume(opt, resume_state['iter'])
+ return resume_state
+
+
+def train_pipeline(root_path):
+ # parse options, set distributed setting, set random seed
+ opt, args = parse_options(root_path, is_train=True)
+ opt['root_path'] = root_path
+
+ torch.backends.cudnn.benchmark = True
+ # torch.backends.cudnn.deterministic = True
+
+ # load resume states if necessary
+ resume_state = load_resume_state(opt)
+ # mkdir for experiments and logger
+ if resume_state is None:
+ make_exp_dirs(opt)
+ if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name'] and opt['rank'] == 0:
+ mkdir_and_rename(osp.join(opt['root_path'], 'tb_logger', opt['name']))
+
+ # copy the yml file to the experiment root
+ copy_opt_file(args.opt, opt['path']['experiments_root'])
+
+ # WARNING: should not use get_root_logger in the above codes, including the called functions
+ # Otherwise the logger will not be properly initialized
+ log_file = osp.join(opt['path']['log'], f"train_{opt['name']}_{get_time_str()}.log")
+ logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
+ logger.info(get_env_info())
+ logger.info(dict2str(opt))
+ # initialize wandb and tb loggers
+ tb_logger = init_tb_loggers(opt)
+
+ # create train and validation dataloaders
+ result = create_train_val_dataloader(opt, logger)
+ train_loader, train_sampler, val_loaders, total_epochs, total_iters = result
+
+ # create model
+ model = build_model(opt)
+ if resume_state: # resume training
+ model.resume_training(resume_state) # handle optimizers and schedulers
+ logger.info(f"Resuming training from epoch: {resume_state['epoch']}, iter: {resume_state['iter']}.")
+ start_epoch = resume_state['epoch']
+ current_iter = resume_state['iter']
+ else:
+ start_epoch = 0
+ current_iter = 0
+
+ # create message logger (formatted outputs)
+ msg_logger = MessageLogger(opt, current_iter, tb_logger)
+
+ # dataloader prefetcher
+ prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
+ if prefetch_mode is None or prefetch_mode == 'cpu':
+ prefetcher = CPUPrefetcher(train_loader)
+ elif prefetch_mode == 'cuda':
+ prefetcher = CUDAPrefetcher(train_loader, opt)
+ logger.info(f'Use {prefetch_mode} prefetch dataloader')
+ if opt['datasets']['train'].get('pin_memory') is not True:
+ raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
+ else:
+ raise ValueError(f"Wrong prefetch_mode {prefetch_mode}. Supported ones are: None, 'cuda', 'cpu'.")
+
+ # training
+ logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter}')
+ data_timer, iter_timer = AvgTimer(), AvgTimer()
+ start_time = time.time()
+
+ for epoch in range(start_epoch, total_epochs + 1):
+ train_sampler.set_epoch(epoch)
+ prefetcher.reset()
+ train_data = prefetcher.next()
+
+ while train_data is not None:
+ data_timer.record()
+
+ current_iter += 1
+ if current_iter > total_iters:
+ break
+ # update learning rate
+ model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))
+ # training
+ model.feed_data(train_data)
+ model.optimize_parameters(current_iter)
+ iter_timer.record()
+ if current_iter == 1:
+ # reset start time in msg_logger for more accurate eta_time
+ # not work in resume mode
+ msg_logger.reset_start_time()
+ # log
+ if current_iter % opt['logger']['print_freq'] == 0:
+ log_vars = {'epoch': epoch, 'iter': current_iter}
+ log_vars.update({'lrs': model.get_current_learning_rate()})
+ log_vars.update({'time': iter_timer.get_avg_time(), 'data_time': data_timer.get_avg_time()})
+ log_vars.update(model.get_current_log())
+ msg_logger(log_vars)
+
+ # save models and training states
+ if current_iter % opt['logger']['save_checkpoint_freq'] == 0:
+ logger.info('Saving models and training states.')
+ model.save(epoch, current_iter)
+
+ # validation
+ if opt.get('val') is not None and (current_iter % opt['val']['val_freq'] == 0):
+ if len(val_loaders) > 1:
+ logger.warning('Multiple validation datasets are *only* supported by SRModel.')
+ for val_loader in val_loaders:
+ model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
+
+ data_timer.start()
+ iter_timer.start()
+ train_data = prefetcher.next()
+ # end of iter
+
+ # end of epoch
+
+ consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
+ logger.info(f'End of training. Time consumed: {consumed_time}')
+ logger.info('Save the latest model.')
+ model.save(epoch=-1, current_iter=-1) # -1 stands for the latest
+ if opt.get('val') is not None:
+ for val_loader in val_loaders:
+ model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
+ if tb_logger:
+ tb_logger.close()
+
+
+if __name__ == '__main__':
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
+ train_pipeline(root_path)
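+
+# Illustrative usage (an assumption; paths, ports and GPU counts are
+# hypothetical):
+#   # single GPU
+#   python basicsr/train.py -opt options/train/train_example.yml
+#   # distributed, 4 GPUs
+#   python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 \
+#       basicsr/train.py -opt options/train/train_example.yml --launcher pytorch
+# Training resumes automatically with `auto_resume`, or from an explicit
+# `path.resume_state` .state file.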
diff --git a/StableSR/basicsr/utils/__init__.py b/StableSR/basicsr/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9569c50780415b356c8e06edac5d960cf1fe1e91
--- /dev/null
+++ b/StableSR/basicsr/utils/__init__.py
@@ -0,0 +1,47 @@
+from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb
+from .diffjpeg import DiffJPEG
+from .file_client import FileClient
+from .img_process_util import USMSharp, usm_sharp
+from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
+from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
+from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
+from .options import yaml_load
+
+__all__ = [
+ # color_util.py
+ 'bgr2ycbcr',
+ 'rgb2ycbcr',
+ 'rgb2ycbcr_pt',
+ 'ycbcr2bgr',
+ 'ycbcr2rgb',
+ # file_client.py
+ 'FileClient',
+ # img_util.py
+ 'img2tensor',
+ 'tensor2img',
+ 'imfrombytes',
+ 'imwrite',
+ 'crop_border',
+ # logger.py
+ 'MessageLogger',
+ 'AvgTimer',
+ 'init_tb_logger',
+ 'init_wandb_logger',
+ 'get_root_logger',
+ 'get_env_info',
+ # misc.py
+ 'set_random_seed',
+ 'get_time_str',
+ 'mkdir_and_rename',
+ 'make_exp_dirs',
+ 'scandir',
+ 'check_resume',
+ 'sizeof_fmt',
+ # diffjpeg
+ 'DiffJPEG',
+ # img_process_util
+ 'USMSharp',
+ 'usm_sharp',
+ # options
+ 'yaml_load'
+]
diff --git a/StableSR/basicsr/utils/color_util.py b/StableSR/basicsr/utils/color_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..4740d5c98dd0680654e20d46b81ab30dfe936d6e
--- /dev/null
+++ b/StableSR/basicsr/utils/color_util.py
@@ -0,0 +1,208 @@
+import numpy as np
+import torch
+
+
+def rgb2ycbcr(img, y_only=False):
+ """Convert a RGB image to YCbCr image.
+
+ This function produces the same results as Matlab's `rgb2ycbcr` function.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+ y_only (bool): Whether to only return Y channel. Default: False.
+
+ Returns:
+ ndarray: The converted YCbCr image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img)
+ if y_only:
+ out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
+ else:
+ out_img = np.matmul(
+ img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def bgr2ycbcr(img, y_only=False):
+ """Convert a BGR image to YCbCr image.
+
+ The bgr version of rgb2ycbcr.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+ y_only (bool): Whether to only return Y channel. Default: False.
+
+ Returns:
+ ndarray: The converted YCbCr image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img)
+ if y_only:
+ out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
+ else:
+ out_img = np.matmul(
+ img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def ycbcr2rgb(img):
+ """Convert a YCbCr image to RGB image.
+
+ This function produces the same results as Matlab's ycbcr2rgb function.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ ndarray: The converted RGB image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img) * 255
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def ycbcr2bgr(img):
+ """Convert a YCbCr image to BGR image.
+
+ The bgr version of ycbcr2rgb.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ ndarray: The converted BGR image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img) * 255
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
+ [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def _convert_input_type_range(img):
+ """Convert the type and range of the input image.
+
+ It converts the input image to np.float32 type and range of [0, 1].
+ It is mainly used for pre-processing the input image in colorspace
+ conversion functions such as rgb2ycbcr and ycbcr2rgb.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ (ndarray): The converted image with type of np.float32 and range of
+ [0, 1].
+ """
+ img_type = img.dtype
+ img = img.astype(np.float32)
+ if img_type == np.float32:
+ pass
+ elif img_type == np.uint8:
+ img /= 255.
+ else:
+ raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}')
+ return img
+
+
+def _convert_output_type_range(img, dst_type):
+ """Convert the type and range of the image according to dst_type.
+
+ It converts the image to desired type and range. If `dst_type` is np.uint8,
+ images will be converted to np.uint8 type with range [0, 255]. If
+ `dst_type` is np.float32, it converts the image to np.float32 type with
+ range [0, 1].
+ It is mainly used for post-processing images in colorspace conversion
+ functions such as rgb2ycbcr and ycbcr2rgb.
+
+ Args:
+ img (ndarray): The image to be converted with np.float32 type and
+ range [0, 255].
+ dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
+ converts the image to np.uint8 type with range [0, 255]. If
+ dst_type is np.float32, it converts the image to np.float32 type
+ with range [0, 1].
+
+ Returns:
+ (ndarray): The converted image with desired type and range.
+ """
+ if dst_type not in (np.uint8, np.float32):
+ raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}')
+ if dst_type == np.uint8:
+ img = img.round()
+ else:
+ img /= 255.
+ return img.astype(dst_type)
+
+
+def rgb2ycbcr_pt(img, y_only=False):
+ """Convert RGB images to YCbCr images (PyTorch version).
+
+ It implements the ITU-R BT.601 conversion for standard-definition television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ Args:
+ img (Tensor): Images with shape (n, 3, h, w), the range [0, 1], float, RGB format.
+ y_only (bool): Whether to only return Y channel. Default: False.
+
+ Returns:
+ (Tensor): converted images with the shape (n, 3/1, h, w), the range [0, 1], float.
+ """
+ if y_only:
+ weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img)
+ out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0
+ else:
+ weight = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]).to(img)
+ bias = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img)
+ out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias
+
+ out_img = out_img / 255.
+ return out_img
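+
+# Illustrative usage (an assumption, not part of the upstream file):
+#   >>> img = torch.rand(4, 3, 32, 32)      # RGB, float, range [0, 1]
+#   >>> rgb2ycbcr_pt(img, y_only=True).shape
+#   torch.Size([4, 1, 32, 32])
+# The Y-only output lies in roughly [16/255, 235/255] and is what
+# Y-channel PSNR/SSIM evaluation typically operates on.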
diff --git a/StableSR/basicsr/utils/diffjpeg.py b/StableSR/basicsr/utils/diffjpeg.py
new file mode 100644
index 0000000000000000000000000000000000000000..65f96b44f9e7f3f8a589668f0003adf328cc5742
--- /dev/null
+++ b/StableSR/basicsr/utils/diffjpeg.py
@@ -0,0 +1,515 @@
+"""
+Modified from https://github.com/mlomnitz/DiffJPEG
+
+For images not divisible by 8
+https://dsp.stackexchange.com/questions/35339/jpeg-dct-padding/35343#35343
+"""
+import itertools
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+# ------------------------ utils ------------------------#
+y_table = np.array(
+ [[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, 55], [14, 13, 16, 24, 40, 57, 69, 56],
+ [14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, 77], [24, 35, 55, 64, 81, 104, 113, 92],
+ [49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]],
+ dtype=np.float32).T
+y_table = nn.Parameter(torch.from_numpy(y_table))
+c_table = np.empty((8, 8), dtype=np.float32)
+c_table.fill(99)
+c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]).T
+c_table = nn.Parameter(torch.from_numpy(c_table))
+
+
+def diff_round(x):
+ """ Differentiable rounding function
+ """
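+ # torch.round has zero gradient almost everywhere; the cubic residual
+ # (x - round(x))**3 stays close to zero but is smooth, so it keeps the output
+ # near round(x) while letting gradients flow during training.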
+ return torch.round(x) + (x - torch.round(x))**3
+
+
+def quality_to_factor(quality):
+ """ Calculate factor corresponding to quality
+
+ Args:
+ quality(float): Quality for jpeg compression.
+
+ Returns:
+ float: Compression factor.
+ """
+ if quality < 50:
+ quality = 5000. / quality
+ else:
+ quality = 200. - quality * 2
+ return quality / 100.
+
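+# Worked example (a sketch, not part of the original module): quality_to_factor(20)
+# = 5000 / 20 / 100 = 2.5 and quality_to_factor(90) = (200 - 180) / 100 = 0.2, so a
+# lower JPEG quality maps to a larger factor, i.e. coarser quantization tables below.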
+
+# ------------------------ compression ------------------------#
+class RGB2YCbCrJpeg(nn.Module):
+ """ Converts RGB image to YCbCr
+ """
+
+ def __init__(self):
+ super(RGB2YCbCrJpeg, self).__init__()
+ matrix = np.array([[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]],
+ dtype=np.float32).T
+ self.shift = nn.Parameter(torch.tensor([0., 128., 128.]))
+ self.matrix = nn.Parameter(torch.from_numpy(matrix))
+
+ def forward(self, image):
+ """
+ Args:
+ image(Tensor): batch x 3 x height x width
+
+ Returns:
+ Tensor: batch x height x width x 3
+ """
+ image = image.permute(0, 2, 3, 1)
+ result = torch.tensordot(image, self.matrix, dims=1) + self.shift
+ return result.view(image.shape)
+
+
+class ChromaSubsampling(nn.Module):
+ """ Chroma subsampling on CbCr channels
+ """
+
+ def __init__(self):
+ super(ChromaSubsampling, self).__init__()
+
+ def forward(self, image):
+ """
+ Args:
+ image(tensor): batch x height x width x 3
+
+ Returns:
+ y(tensor): batch x height x width
+ cb(tensor): batch x height/2 x width/2
+ cr(tensor): batch x height/2 x width/2
+ """
+ image_2 = image.permute(0, 3, 1, 2).clone()
+ cb = F.avg_pool2d(image_2[:, 1, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
+ cr = F.avg_pool2d(image_2[:, 2, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
+ cb = cb.permute(0, 2, 3, 1)
+ cr = cr.permute(0, 2, 3, 1)
+ return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3)
+
+
+class BlockSplitting(nn.Module):
+ """ Splitting image into patches
+ """
+
+ def __init__(self):
+ super(BlockSplitting, self).__init__()
+ self.k = 8
+
+ def forward(self, image):
+ """
+ Args:
+ image(tensor): batch x height x width
+
+ Returns:
+ Tensor: batch x h*w/64 x 8 x 8
+ """
+ height, _ = image.shape[1:3]
+ batch_size = image.shape[0]
+ image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k)
+ image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
+ return image_transposed.contiguous().view(batch_size, -1, self.k, self.k)
+
+
+class DCT8x8(nn.Module):
+ """ Discrete Cosine Transformation
+ """
+
+ def __init__(self):
+ super(DCT8x8, self).__init__()
+ tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
+ for x, y, u, v in itertools.product(range(8), repeat=4):
+ tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16)
+ alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
+ self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
+ self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())
+
+ def forward(self, image):
+ """
+ Args:
+ image(tensor): batch x height x width
+
+ Returns:
+ Tensor: batch x height x width
+ """
+ image = image - 128
+ result = self.scale * torch.tensordot(image, self.tensor, dims=2)
+ result = result.view(image.shape)
+ return result
+
+
+class YQuantize(nn.Module):
+ """ JPEG Quantization for Y channel
+
+ Args:
+ rounding(function): rounding function to use
+ """
+
+ def __init__(self, rounding):
+ super(YQuantize, self).__init__()
+ self.rounding = rounding
+ self.y_table = y_table
+
+ def forward(self, image, factor=1):
+ """
+ Args:
+ image(tensor): batch x height x width
+
+ Returns:
+ Tensor: batch x height x width
+ """
+ if isinstance(factor, (int, float)):
+ image = image.float() / (self.y_table * factor)
+ else:
+ b = factor.size(0)
+ table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
+ image = image.float() / table
+ image = self.rounding(image)
+ return image
+
+
+class CQuantize(nn.Module):
+ """ JPEG Quantization for CbCr channels
+
+ Args:
+ rounding(function): rounding function to use
+ """
+
+ def __init__(self, rounding):
+ super(CQuantize, self).__init__()
+ self.rounding = rounding
+ self.c_table = c_table
+
+ def forward(self, image, factor=1):
+ """
+ Args:
+ image(tensor): batch x height x width
+
+ Returns:
+ Tensor: batch x height x width
+ """
+ if isinstance(factor, (int, float)):
+ image = image.float() / (self.c_table * factor)
+ else:
+ b = factor.size(0)
+ table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
+ image = image.float() / table
+ image = self.rounding(image)
+ return image
+
+
+class CompressJpeg(nn.Module):
+ """Full JPEG compression algorithm
+
+ Args:
+ rounding(function): rounding function to use
+ """
+
+ def __init__(self, rounding=torch.round):
+ super(CompressJpeg, self).__init__()
+ self.l1 = nn.Sequential(RGB2YCbCrJpeg(), ChromaSubsampling())
+ self.l2 = nn.Sequential(BlockSplitting(), DCT8x8())
+ self.c_quantize = CQuantize(rounding=rounding)
+ self.y_quantize = YQuantize(rounding=rounding)
+
+ def forward(self, image, factor=1):
+ """
+ Args:
+ image(tensor): batch x 3 x height x width
+
+ Returns:
+ tuple(Tensor): Compressed y, cb and cr components, each of shape batch x h*w/64 x 8 x 8.
+ """
+ y, cb, cr = self.l1(image * 255)
+ components = {'y': y, 'cb': cb, 'cr': cr}
+ for k in components.keys():
+ comp = self.l2(components[k])
+ if k in ('cb', 'cr'):
+ comp = self.c_quantize(comp, factor=factor)
+ else:
+ comp = self.y_quantize(comp, factor=factor)
+
+ components[k] = comp
+
+ return components['y'], components['cb'], components['cr']
+
+
+# ------------------------ decompression ------------------------#
+
+
+class YDequantize(nn.Module):
+ """Dequantize Y channel
+ """
+
+ def __init__(self):
+ super(YDequantize, self).__init__()
+ self.y_table = y_table
+
+ def forward(self, image, factor=1):
+ """
+ Args:
+ image(tensor): batch x height x width
+
+ Returns:
+ Tensor: batch x height x width
+ """
+ if isinstance(factor, (int, float)):
+ out = image * (self.y_table * factor)
+ else:
+ b = factor.size(0)
+ table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
+ out = image * table
+ return out
+
+
+class CDequantize(nn.Module):
+ """Dequantize CbCr channel
+ """
+
+ def __init__(self):
+ super(CDequantize, self).__init__()
+ self.c_table = c_table
+
+ def forward(self, image, factor=1):
+ """
+ Args:
+ image(tensor): batch x height x width
+
+ Returns:
+ Tensor: batch x height x width
+ """
+ if isinstance(factor, (int, float)):
+ out = image * (self.c_table * factor)
+ else:
+ b = factor.size(0)
+ table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
+ out = image * table
+ return out
+
+
+class iDCT8x8(nn.Module):
+ """Inverse discrete Cosine Transformation
+ """
+
+ def __init__(self):
+ super(iDCT8x8, self).__init__()
+ alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
+ self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
+ tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
+ for x, y, u, v in itertools.product(range(8), repeat=4):
+ tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16)
+ self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
+
+ def forward(self, image):
+ """
+ Args:
+ image(tensor): batch x height x width
+
+ Returns:
+ Tensor: batch x height x width
+ """
+ image = image * self.alpha
+ result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
+ result = result.view(image.shape)
+ return result
+
+
+class BlockMerging(nn.Module):
+ """Merge patches into image
+ """
+
+ def __init__(self):
+ super(BlockMerging, self).__init__()
+
+ def forward(self, patches, height, width):
+ """
+ Args:
+ patches(tensor): batch x height*width/64 x 8 x 8
+ height(int)
+ width(int)
+
+ Returns:
+ Tensor: batch x height x width
+ """
+ k = 8
+ batch_size = patches.shape[0]
+ image_reshaped = patches.view(batch_size, height // k, width // k, k, k)
+ image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
+ return image_transposed.contiguous().view(batch_size, height, width)
+
+
+class ChromaUpsampling(nn.Module):
+ """Upsample chroma layers
+ """
+
+ def __init__(self):
+ super(ChromaUpsampling, self).__init__()
+
+ def forward(self, y, cb, cr):
+ """
+ Args:
+ y(tensor): y channel image
+ cb(tensor): cb channel
+ cr(tensor): cr channel
+
+ Returns:
+ Tensor: batch x height x width x 3
+ """
+
+ def repeat(x, k=2):
+ height, width = x.shape[1:3]
+ x = x.unsqueeze(-1)
+ x = x.repeat(1, 1, k, k)
+ x = x.view(-1, height * k, width * k)
+ return x
+
+ cb = repeat(cb)
+ cr = repeat(cr)
+ return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3)
+
+
+class YCbCr2RGBJpeg(nn.Module):
+ """Converts YCbCr image to RGB JPEG
+ """
+
+ def __init__(self):
+ super(YCbCr2RGBJpeg, self).__init__()
+
+ matrix = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T
+ self.shift = nn.Parameter(torch.tensor([0, -128., -128.]))
+ self.matrix = nn.Parameter(torch.from_numpy(matrix))
+
+ def forward(self, image):
+ """
+ Args:
+ image(tensor): batch x height x width x 3
+
+ Returns:
+ Tensor: batch x 3 x height x width
+ """
+ result = torch.tensordot(image + self.shift, self.matrix, dims=1)
+ return result.view(image.shape).permute(0, 3, 1, 2)
+
+
+class DeCompressJpeg(nn.Module):
+ """Full JPEG decompression algorithm
+
+ Args:
+ rounding(function): rounding function to use
+ """
+
+ def __init__(self, rounding=torch.round):
+ super(DeCompressJpeg, self).__init__()
+ self.c_dequantize = CDequantize()
+ self.y_dequantize = YDequantize()
+ self.idct = iDCT8x8()
+ self.merging = BlockMerging()
+ self.chroma = ChromaUpsampling()
+ self.colors = YCbCr2RGBJpeg()
+
+ def forward(self, y, cb, cr, imgh, imgw, factor=1):
+ """
+ Args:
+ y, cb, cr (tensor): Compressed components, each of shape batch x h*w/64 x 8 x 8
+ imgh(int)
+ imgw(int)
+ factor(float)
+
+ Returns:
+ Tensor: batch x 3 x height x width
+ """
+ components = {'y': y, 'cb': cb, 'cr': cr}
+ for k in components.keys():
+ if k in ('cb', 'cr'):
+ comp = self.c_dequantize(components[k], factor=factor)
+ height, width = int(imgh / 2), int(imgw / 2)
+ else:
+ comp = self.y_dequantize(components[k], factor=factor)
+ height, width = imgh, imgw
+ comp = self.idct(comp)
+ components[k] = self.merging(comp, height, width)
+ #
+ image = self.chroma(components['y'], components['cb'], components['cr'])
+ image = self.colors(image)
+
+ image = torch.min(255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image))
+ return image / 255
+
+
+# ------------------------ main DiffJPEG ------------------------ #
+
+
+class DiffJPEG(nn.Module):
+ """Differentiable JPEG compression and decompression. The result differs
+ slightly from cv2's JPEG codec. DiffJPEG supports batch processing.
+
+ Args:
+ differentiable(bool): If True, use the custom differentiable rounding function; if False, use the standard torch.round.
+ """
+
+ def __init__(self, differentiable=True):
+ super(DiffJPEG, self).__init__()
+ if differentiable:
+ rounding = diff_round
+ else:
+ rounding = torch.round
+
+ self.compress = CompressJpeg(rounding=rounding)
+ self.decompress = DeCompressJpeg(rounding=rounding)
+
+ def forward(self, x, quality):
+ """
+ Args:
+ x (Tensor): Input image, bchw, rgb, [0, 1]
+ quality(float): Quality factor for jpeg compression scheme.
+ """
+ factor = quality
+ if isinstance(factor, (int, float)):
+ factor = quality_to_factor(factor)
+ else:
+ for i in range(factor.size(0)):
+ factor[i] = quality_to_factor(factor[i])
+ h, w = x.size()[-2:]
+ h_pad, w_pad = 0, 0
+ # pad to multiples of 16: chroma subsampling halves the height/width and the DCT operates on 8x8 blocks
+ if h % 16 != 0:
+ h_pad = 16 - h % 16
+ if w % 16 != 0:
+ w_pad = 16 - w % 16
+ x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0)
+
+ y, cb, cr = self.compress(x, factor=factor)
+ recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor)
+ recovered = recovered[:, :, 0:h, 0:w]
+ return recovered
+
+
+if __name__ == '__main__':
+ import cv2
+
+ from basicsr.utils import img2tensor, tensor2img
+
+ img_gt = cv2.imread('test.png') / 255.
+
+ # -------------- cv2 -------------- #
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 20]
+ _, encimg = cv2.imencode('.jpg', img_gt * 255., encode_param)
+ img_lq = np.float32(cv2.imdecode(encimg, 1))
+ cv2.imwrite('cv2_JPEG_20.png', img_lq)
+
+ # -------------- DiffJPEG -------------- #
+ jpeger = DiffJPEG(differentiable=False).cuda()
+ img_gt = img2tensor(img_gt)
+ img_gt = torch.stack([img_gt, img_gt]).cuda()
+ quality = img_gt.new_tensor([20, 40])
+ out = jpeger(img_gt, quality=quality)
+
+ cv2.imwrite('pt_JPEG_20.png', tensor2img(out[0]))
+ cv2.imwrite('pt_JPEG_40.png', tensor2img(out[1]))
diff --git a/StableSR/basicsr/utils/dist_util.py b/StableSR/basicsr/utils/dist_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fab887b2cb1ce8533d2e8fdee72ae0c24f68fd0
--- /dev/null
+++ b/StableSR/basicsr/utils/dist_util.py
@@ -0,0 +1,82 @@
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
+import functools
+import os
+import subprocess
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+
+def init_dist(launcher, backend='nccl', **kwargs):
+ if mp.get_start_method(allow_none=True) is None:
+ mp.set_start_method('spawn')
+ if launcher == 'pytorch':
+ _init_dist_pytorch(backend, **kwargs)
+ elif launcher == 'slurm':
+ _init_dist_slurm(backend, **kwargs)
+ else:
+ raise ValueError(f'Invalid launcher type: {launcher}')
+
+
+def _init_dist_pytorch(backend, **kwargs):
+ rank = int(os.environ['RANK'])
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(rank % num_gpus)
+ dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_slurm(backend, port=None):
+ """Initialize slurm distributed training environment.
+
+ If argument ``port`` is not specified, the master port will be taken from the
+ system environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in the
+ system environment, a default port ``29500`` will be used.
+
+ Args:
+ backend (str): Backend of torch.distributed.
+ port (int, optional): Master port. Defaults to None.
+ """
+ proc_id = int(os.environ['SLURM_PROCID'])
+ ntasks = int(os.environ['SLURM_NTASKS'])
+ node_list = os.environ['SLURM_NODELIST']
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(proc_id % num_gpus)
+ addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
+ # specify master port
+ if port is not None:
+ os.environ['MASTER_PORT'] = str(port)
+ elif 'MASTER_PORT' in os.environ:
+ pass # use MASTER_PORT in the environment variable
+ else:
+ # 29500 is torch.distributed default port
+ os.environ['MASTER_PORT'] = '29500'
+ os.environ['MASTER_ADDR'] = addr
+ os.environ['WORLD_SIZE'] = str(ntasks)
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
+ os.environ['RANK'] = str(proc_id)
+ dist.init_process_group(backend=backend)
+
+
+def get_dist_info():
+ if dist.is_available():
+ initialized = dist.is_initialized()
+ else:
+ initialized = False
+ if initialized:
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ else:
+ rank = 0
+ world_size = 1
+ return rank, world_size
+
+
+def master_only(func):
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ rank, _ = get_dist_info()
+ if rank == 0:
+ return func(*args, **kwargs)
+
+ return wrapper
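+
+
+# Minimal usage sketch (illustrative, not part of the original module); the
+# launcher string and environment variables are assumed to follow a typical
+# torch.distributed / Slurm setup:
+#     >>> init_dist('pytorch')                 # reads RANK from the environment
+#     >>> rank, world_size = get_dist_info()
+#     >>> @master_only
+#     ... def save_checkpoint():
+#     ...     ...                              # runs only on rank 0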
diff --git a/StableSR/basicsr/utils/download_util.py b/StableSR/basicsr/utils/download_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..f73abd0e1831b8cab6277d780331a5103785b9ec
--- /dev/null
+++ b/StableSR/basicsr/utils/download_util.py
@@ -0,0 +1,98 @@
+import math
+import os
+import requests
+from torch.hub import download_url_to_file, get_dir
+from tqdm import tqdm
+from urllib.parse import urlparse
+
+from .misc import sizeof_fmt
+
+
+def download_file_from_google_drive(file_id, save_path):
+ """Download files from google drive.
+
+ Reference: https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive
+
+ Args:
+ file_id (str): File id.
+ save_path (str): Save path.
+ """
+
+ session = requests.Session()
+ URL = 'https://docs.google.com/uc?export=download'
+ params = {'id': file_id}
+
+ response = session.get(URL, params=params, stream=True)
+ token = get_confirm_token(response)
+ if token:
+ params['confirm'] = token
+ response = session.get(URL, params=params, stream=True)
+
+ # get file size
+ response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
+ if 'Content-Range' in response_file_size.headers:
+ file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
+ else:
+ file_size = None
+
+ save_response_content(response, save_path, file_size)
+
+
+def get_confirm_token(response):
+ for key, value in response.cookies.items():
+ if key.startswith('download_warning'):
+ return value
+ return None
+
+
+def save_response_content(response, destination, file_size=None, chunk_size=32768):
+ if file_size is not None:
+ pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
+
+ readable_file_size = sizeof_fmt(file_size)
+ else:
+ pbar = None
+
+ with open(destination, 'wb') as f:
+ downloaded_size = 0
+ for chunk in response.iter_content(chunk_size):
+ downloaded_size += len(chunk)
+ if pbar is not None:
+ pbar.update(1)
+ pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
+ if chunk: # filter out keep-alive new chunks
+ f.write(chunk)
+ if pbar is not None:
+ pbar.close()
+
+
+def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
+ """Load file from http url; the model will be downloaded if necessary.
+
+ Reference: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
+
+ Args:
+ url (str): URL to be downloaded.
+ model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
+ Default: None.
+ progress (bool): Whether to show the download progress. Default: True.
+ file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
+
+ Returns:
+ str: The path to the downloaded file.
+ """
+ if model_dir is None: # use the pytorch hub_dir
+ hub_dir = get_dir()
+ model_dir = os.path.join(hub_dir, 'checkpoints')
+
+ os.makedirs(model_dir, exist_ok=True)
+
+ parts = urlparse(url)
+ filename = os.path.basename(parts.path)
+ if file_name is not None:
+ filename = file_name
+ cached_file = os.path.abspath(os.path.join(model_dir, filename))
+ if not os.path.exists(cached_file):
+ print(f'Downloading: "{url}" to {cached_file}\n')
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
+ return cached_file
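+
+
+# Minimal usage sketch (the URL and directory below are placeholders, not real assets):
+#     >>> path = load_file_from_url(
+#     ...     'https://example.com/models/some_model.pth',
+#     ...     model_dir='experiments/pretrained_models')
+#     >>> print(path)  # absolute path of the cached file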
diff --git a/StableSR/basicsr/utils/file_client.py b/StableSR/basicsr/utils/file_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..89d83ab9e0d4314f8cdf2393908a561c6d1dca92
--- /dev/null
+++ b/StableSR/basicsr/utils/file_client.py
@@ -0,0 +1,167 @@
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
+from abc import ABCMeta, abstractmethod
+
+
+class BaseStorageBackend(metaclass=ABCMeta):
+ """Abstract class of storage backends.
+
+ All backends need to implement two apis: ``get()`` and ``get_text()``.
+ ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
+ as text.
+ """
+
+ @abstractmethod
+ def get(self, filepath):
+ pass
+
+ @abstractmethod
+ def get_text(self, filepath):
+ pass
+
+
+class MemcachedBackend(BaseStorageBackend):
+ """Memcached storage backend.
+
+ Attributes:
+ server_list_cfg (str): Config file for memcached server list.
+ client_cfg (str): Config file for memcached client.
+ sys_path (str | None): Additional path to be appended to `sys.path`.
+ Default: None.
+ """
+
+ def __init__(self, server_list_cfg, client_cfg, sys_path=None):
+ if sys_path is not None:
+ import sys
+ sys.path.append(sys_path)
+ try:
+ import mc
+ except ImportError:
+ raise ImportError('Please install memcached to enable MemcachedBackend.')
+
+ self.server_list_cfg = server_list_cfg
+ self.client_cfg = client_cfg
+ self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
+ # mc.pyvector serves as a pointer to a memory cache
+ self._mc_buffer = mc.pyvector()
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ import mc
+ self._client.Get(filepath, self._mc_buffer)
+ value_buf = mc.ConvertBuffer(self._mc_buffer)
+ return value_buf
+
+ def get_text(self, filepath):
+ raise NotImplementedError
+
+
+class HardDiskBackend(BaseStorageBackend):
+ """Raw hard disks storage backend."""
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ with open(filepath, 'rb') as f:
+ value_buf = f.read()
+ return value_buf
+
+ def get_text(self, filepath):
+ filepath = str(filepath)
+ with open(filepath, 'r') as f:
+ value_buf = f.read()
+ return value_buf
+
+
+class LmdbBackend(BaseStorageBackend):
+ """Lmdb storage backend.
+
+ Args:
+ db_paths (str | list[str]): Lmdb database paths.
+ client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
+ readonly (bool, optional): Lmdb environment parameter. If True,
+ disallow any write operations. Default: True.
+ lock (bool, optional): Lmdb environment parameter. If False, when
+ concurrent access occurs, do not lock the database. Default: False.
+ readahead (bool, optional): Lmdb environment parameter. If False,
+ disable the OS filesystem readahead mechanism, which may improve
+ random read performance when a database is larger than RAM.
+ Default: False.
+
+ Attributes:
+ db_paths (list): Lmdb database path.
+ _client (list): A list of several lmdb envs.
+ """
+
+ def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
+ try:
+ import lmdb
+ except ImportError:
+ raise ImportError('Please install lmdb to enable LmdbBackend.')
+
+ if isinstance(client_keys, str):
+ client_keys = [client_keys]
+
+ if isinstance(db_paths, list):
+ self.db_paths = [str(v) for v in db_paths]
+ elif isinstance(db_paths, str):
+ self.db_paths = [str(db_paths)]
+ assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
+ f'but received {len(client_keys)} and {len(self.db_paths)}.')
+
+ self._client = {}
+ for client, path in zip(client_keys, self.db_paths):
+ self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
+
+ def get(self, filepath, client_key):
+ """Get values according to the filepath from one lmdb named client_key.
+
+ Args:
+ filepath (str | obj:`Path`): Here, filepath is the lmdb key.
+ client_key (str): Used for distinguishing different lmdb envs.
+ """
+ filepath = str(filepath)
+ assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.')
+ client = self._client[client_key]
+ with client.begin(write=False) as txn:
+ value_buf = txn.get(filepath.encode('ascii'))
+ return value_buf
+
+ def get_text(self, filepath):
+ raise NotImplementedError
+
+
+class FileClient(object):
+ """A general file client to access files in different backend.
+
+ The client loads a file or text from its path with a specified backend
+ and returns it as a binary stream. It can also register other backend
+ accessors with a given name and backend class.
+
+ Attributes:
+ backend (str): The storage backend type. Options are "disk",
+ "memcached" and "lmdb".
+ client (:obj:`BaseStorageBackend`): The backend object.
+ """
+
+ _backends = {
+ 'disk': HardDiskBackend,
+ 'memcached': MemcachedBackend,
+ 'lmdb': LmdbBackend,
+ }
+
+ def __init__(self, backend='disk', **kwargs):
+ if backend not in self._backends:
+ raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
+ f' are {list(self._backends.keys())}')
+ self.backend = backend
+ self.client = self._backends[backend](**kwargs)
+
+ def get(self, filepath, client_key='default'):
+ # client_key is used only for lmdb, where different fileclients have
+ # different lmdb environments.
+ if self.backend == 'lmdb':
+ return self.client.get(filepath, client_key)
+ else:
+ return self.client.get(filepath)
+
+ def get_text(self, filepath):
+ return self.client.get_text(filepath)
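+
+
+# Minimal usage sketch (paths are placeholders): the disk backend needs no extra
+# arguments, while the lmdb backend takes db_paths/client_keys at construction.
+#     >>> client = FileClient('disk')
+#     >>> img_bytes = client.get('datasets/example/0001.png')
+#     >>> lmdb_client = FileClient('lmdb', db_paths=['datasets/example.lmdb'])
+#     >>> img_bytes = lmdb_client.get('0001', client_key='default')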
diff --git a/StableSR/basicsr/utils/flow_util.py b/StableSR/basicsr/utils/flow_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d7180b4e9b5c8f2eb36a9a0e4ff6affdaae84b8
--- /dev/null
+++ b/StableSR/basicsr/utils/flow_util.py
@@ -0,0 +1,170 @@
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501
+import cv2
+import numpy as np
+import os
+
+
+def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
+ """Read an optical flow map.
+
+ Args:
+ flow_path (ndarray or str): Flow path.
+ quantize (bool): Whether to read a quantized flow pair. If set to True,
+ remaining args will be passed to :func:`dequantize_flow`.
+ concat_axis (int): The axis that dx and dy are concatenated,
+ can be either 0 or 1. Ignored if quantize is False.
+
+ Returns:
+ ndarray: Optical flow represented as a (h, w, 2) numpy array
+ """
+ if quantize:
+ assert concat_axis in [0, 1]
+ cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
+ if cat_flow.ndim != 2:
+ raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.')
+ assert cat_flow.shape[concat_axis] % 2 == 0
+ dx, dy = np.split(cat_flow, 2, axis=concat_axis)
+ flow = dequantize_flow(dx, dy, *args, **kwargs)
+ else:
+ with open(flow_path, 'rb') as f:
+ try:
+ header = f.read(4).decode('utf-8')
+ except Exception:
+ raise IOError(f'Invalid flow file: {flow_path}')
+ else:
+ if header != 'PIEH':
+ raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH')
+
+ w = np.fromfile(f, np.int32, 1).squeeze()
+ h = np.fromfile(f, np.int32, 1).squeeze()
+ flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
+
+ return flow.astype(np.float32)
+
+
+def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
+ """Write optical flow to file.
+
+ If the flow is not quantized, it will be saved as a .flo file losslessly;
+ otherwise it is saved as a jpeg image, which is lossy but much smaller. (dx and dy
+ will be concatenated along `concat_axis` into a single image if quantize is True.)
+
+ Args:
+ flow (ndarray): (h, w, 2) array of optical flow.
+ filename (str): Output filepath.
+ quantize (bool): Whether to quantize the flow and save it to 2 jpeg
+ images. If set to True, remaining args will be passed to
+ :func:`quantize_flow`.
+ concat_axis (int): The axis that dx and dy are concatenated,
+ can be either 0 or 1. Ignored if quantize is False.
+ """
+ if not quantize:
+ with open(filename, 'wb') as f:
+ f.write('PIEH'.encode('utf-8'))
+ np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
+ flow = flow.astype(np.float32)
+ flow.tofile(f)
+ f.flush()
+ else:
+ assert concat_axis in [0, 1]
+ dx, dy = quantize_flow(flow, *args, **kwargs)
+ dxdy = np.concatenate((dx, dy), axis=concat_axis)
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+ cv2.imwrite(filename, dxdy)
+
+
+def quantize_flow(flow, max_val=0.02, norm=True):
+ """Quantize flow to [0, 255].
+
+ After this step, the size of flow will be much smaller, and can be
+ dumped as jpeg images.
+
+ Args:
+ flow (ndarray): (h, w, 2) array of optical flow.
+ max_val (float): Maximum value of flow, values beyond
+ [-max_val, max_val] will be truncated.
+ norm (bool): Whether to divide flow values by image width/height.
+
+ Returns:
+ tuple[ndarray]: Quantized dx and dy.
+ """
+ h, w, _ = flow.shape
+ dx = flow[..., 0]
+ dy = flow[..., 1]
+ if norm:
+ dx = dx / w # avoid inplace operations
+ dy = dy / h
+ # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
+ flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]]
+ return tuple(flow_comps)
+
+
+def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
+ """Recover from quantized flow.
+
+ Args:
+ dx (ndarray): Quantized dx.
+ dy (ndarray): Quantized dy.
+ max_val (float): Maximum value used when quantizing.
+ denorm (bool): Whether to multiply flow values with width/height.
+
+ Returns:
+ ndarray: Dequantized flow.
+ """
+ assert dx.shape == dy.shape
+ assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
+
+ dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
+
+ if denorm:
+ dx *= dx.shape[1]
+ dy *= dx.shape[0]
+ flow = np.dstack((dx, dy))
+ return flow
+
+
+def quantize(arr, min_val, max_val, levels, dtype=np.int64):
+ """Quantize an array of (-inf, inf) to [0, levels-1].
+
+ Args:
+ arr (ndarray): Input array.
+ min_val (scalar): Minimum value to be clipped.
+ max_val (scalar): Maximum value to be clipped.
+ levels (int): Quantization levels.
+ dtype (np.type): The type of the quantized array.
+
+ Returns:
+ ndarray: Quantized array.
+ """
+ if not (isinstance(levels, int) and levels > 1):
+ raise ValueError(f'levels must be a positive integer, but got {levels}')
+ if min_val >= max_val:
+ raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+ arr = np.clip(arr, min_val, max_val) - min_val
+ quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
+
+ return quantized_arr
+
+
+def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
+ """Dequantize an array.
+
+ Args:
+ arr (ndarray): Input array.
+ min_val (scalar): Minimum value to be clipped.
+ max_val (scalar): Maximum value to be clipped.
+ levels (int): Quantization levels.
+ dtype (np.type): The type of the dequantized array.
+
+ Returns:
+ ndarray: Dequantized array.
+ """
+ if not (isinstance(levels, int) and levels > 1):
+ raise ValueError(f'levels must be a positive integer, but got {levels}')
+ if min_val >= max_val:
+ raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+ dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val
+
+ return dequantized_arr
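+
+
+# Minimal round-trip sketch (values are illustrative): quantization is lossy, so the
+# recovered flow only approximates the original.
+#     >>> flow = np.random.uniform(-0.01, 0.01, (64, 64, 2)).astype(np.float32)
+#     >>> dx, dy = quantize_flow(flow, max_val=0.02, norm=False)
+#     >>> flow_rec = dequantize_flow(dx, dy, max_val=0.02, denorm=False)
+#     >>> # the maximum error is bounded by roughly one quantization step (0.04 / 255)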
diff --git a/StableSR/basicsr/utils/img_process_util.py b/StableSR/basicsr/utils/img_process_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..52e02f09930dbf13bcd12bbe16b76e4fce52578e
--- /dev/null
+++ b/StableSR/basicsr/utils/img_process_util.py
@@ -0,0 +1,83 @@
+import cv2
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+
+def filter2D(img, kernel):
+ """PyTorch version of cv2.filter2D
+
+ Args:
+ img (Tensor): (b, c, h, w)
+ kernel (Tensor): (b, k, k)
+ """
+ k = kernel.size(-1)
+ b, c, h, w = img.size()
+ if k % 2 == 1:
+ img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
+ else:
+ raise ValueError('Wrong kernel size')
+
+ ph, pw = img.size()[-2:]
+
+ if kernel.size(0) == 1:
+ # apply the same kernel to all batch images
+ img = img.view(b * c, 1, ph, pw)
+ kernel = kernel.view(1, 1, k, k)
+ return F.conv2d(img, kernel, padding=0).view(b, c, h, w)
+ else:
+ img = img.view(1, b * c, ph, pw)
+ kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
+ return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)
+
+
+def usm_sharp(img, weight=0.5, radius=50, threshold=10):
+ """USM sharpening.
+
+ Input image: I; Blurry image: B.
+ 1. sharp = I + weight * (I - B)
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
+ 3. Blur the mask to obtain a soft mask.
+ 4. Out = Mask * sharp + (1 - Mask) * I
+
+ Args:
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+ weight (float): Sharp weight. Default: 0.5.
+ radius (float): Kernel size of Gaussian blur. Default: 50.
+ threshold (int): Residual threshold (after scaling by 255) used to build the sharpening mask. Default: 10.
+ """
+ if radius % 2 == 0:
+ radius += 1
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
+ residual = img - blur
+ mask = np.abs(residual) * 255 > threshold
+ mask = mask.astype('float32')
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+ sharp = img + weight * residual
+ sharp = np.clip(sharp, 0, 1)
+ return soft_mask * sharp + (1 - soft_mask) * img
+
+
+class USMSharp(torch.nn.Module):
+
+ def __init__(self, radius=50, sigma=0):
+ super(USMSharp, self).__init__()
+ if radius % 2 == 0:
+ radius += 1
+ self.radius = radius
+ kernel = cv2.getGaussianKernel(radius, sigma)
+ kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0)
+ self.register_buffer('kernel', kernel)
+
+ def forward(self, img, weight=0.5, threshold=10):
+ blur = filter2D(img, self.kernel)
+ residual = img - blur
+
+ mask = torch.abs(residual) * 255 > threshold
+ mask = mask.float()
+ soft_mask = filter2D(mask, self.kernel)
+ sharp = img + weight * residual
+ sharp = torch.clip(sharp, 0, 1)
+ return soft_mask * sharp + (1 - soft_mask) * img
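+
+
+# Minimal usage sketch (file name and shapes are illustrative):
+#     >>> img = cv2.imread('example.png').astype(np.float32) / 255.  # HWC, BGR, [0, 1]
+#     >>> sharpened = usm_sharp(img, weight=0.5)
+#     >>> usm = USMSharp()                     # torch version with a reusable Gaussian kernel
+#     >>> batch = torch.rand(1, 3, 64, 64)
+#     >>> sharpened_t = usm(batch, weight=0.5)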
diff --git a/StableSR/basicsr/utils/img_util.py b/StableSR/basicsr/utils/img_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbce5dba5b01deb78f2453edc801a76e6a126998
--- /dev/null
+++ b/StableSR/basicsr/utils/img_util.py
@@ -0,0 +1,172 @@
+import cv2
+import math
+import numpy as np
+import os
+import torch
+from torchvision.utils import make_grid
+
+
+def img2tensor(imgs, bgr2rgb=True, float32=True):
+ """Numpy array to tensor.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Input images.
+ bgr2rgb (bool): Whether to change bgr to rgb.
+ float32 (bool): Whether to change to float32.
+
+ Returns:
+ list[tensor] | tensor: Tensor images. If returned results only have
+ one element, just return tensor.
+ """
+
+ def _totensor(img, bgr2rgb, float32):
+ if img.shape[2] == 3 and bgr2rgb:
+ if img.dtype == 'float64':
+ img = img.astype('float32')
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = torch.from_numpy(img.transpose(2, 0, 1))
+ if float32:
+ img = img.float()
+ return img
+
+ if isinstance(imgs, list):
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
+ else:
+ return _totensor(imgs, bgr2rgb, float32)
+
+
+def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
+ """Convert torch Tensors into image numpy arrays.
+
+ After clamping to [min, max], values will be normalized to [0, 1].
+
+ Args:
+ tensor (Tensor or list[Tensor]): Accept shapes:
+ 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
+ 2) 3D Tensor of shape (3/1 x H x W);
+ 3) 2D Tensor of shape (H x W).
+ Tensor channel should be in RGB order.
+ rgb2bgr (bool): Whether to change rgb to bgr.
+ out_type (numpy type): output types. If ``np.uint8``, transform outputs
+ to uint8 type with range [0, 255]; otherwise, float type with
+ range [0, 1]. Default: ``np.uint8``.
+ min_max (tuple[int]): min and max values for clamp.
+
+ Returns:
+ (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) or 2D ndarray of
+ shape (H x W). The channel order is BGR.
+ """
+ if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
+ raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
+
+ if torch.is_tensor(tensor):
+ tensor = [tensor]
+ result = []
+ for _tensor in tensor:
+ _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
+ _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
+
+ n_dim = _tensor.dim()
+ if n_dim == 4:
+ img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
+ img_np = img_np.transpose(1, 2, 0)
+ if rgb2bgr:
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+ elif n_dim == 3:
+ img_np = _tensor.numpy()
+ img_np = img_np.transpose(1, 2, 0)
+ if img_np.shape[2] == 1: # gray image
+ img_np = np.squeeze(img_np, axis=2)
+ else:
+ if rgb2bgr:
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+ elif n_dim == 2:
+ img_np = _tensor.numpy()
+ else:
+ raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
+ if out_type == np.uint8:
+ # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
+ img_np = (img_np * 255.0).round()
+ img_np = img_np.astype(out_type)
+ result.append(img_np)
+ if len(result) == 1:
+ result = result[0]
+ return result
+
+
+def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
+ """This implementation is slightly faster than tensor2img.
+ It now only supports torch tensor with shape (1, c, h, w).
+
+ Args:
+ tensor (Tensor): Now only support torch tensor with (1, c, h, w).
+ rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
+ min_max (tuple[int]): min and max values for clamp.
+ """
+ output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
+ output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
+ output = output.type(torch.uint8).cpu().numpy()
+ if rgb2bgr:
+ output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
+ return output
+
+
+def imfrombytes(content, flag='color', float32=False):
+ """Read an image from bytes.
+
+ Args:
+ content (bytes): Image bytes got from files or other streams.
+ flag (str): Flags specifying the color type of a loaded image,
+ candidates are `color`, `grayscale` and `unchanged`.
+ float32 (bool): Whether to change to float32. If True, will also normalize
+ to [0, 1]. Default: False.
+
+ Returns:
+ ndarray: Loaded image array.
+ """
+ img_np = np.frombuffer(content, np.uint8)
+ imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
+ img = cv2.imdecode(img_np, imread_flags[flag])
+ if float32:
+ img = img.astype(np.float32) / 255.
+ return img
+
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+ """Write image to file.
+
+ Args:
+ img (ndarray): Image array to be written.
+ file_path (str): Image file path.
+ params (None or list): Same as opencv's :func:`imwrite` interface.
+ auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+ whether to create it automatically.
+
+ Returns:
+ bool: Successful or not.
+ """
+ if auto_mkdir:
+ dir_name = os.path.abspath(os.path.dirname(file_path))
+ os.makedirs(dir_name, exist_ok=True)
+ ok = cv2.imwrite(file_path, img, params)
+ if not ok:
+ raise IOError('Failed in writing images.')
+
+
+def crop_border(imgs, crop_border):
+ """Crop borders of images.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
+ crop_border (int): Crop border for each end of height and width.
+
+ Returns:
+ list[ndarray]: Cropped images.
+ """
+ if crop_border == 0:
+ return imgs
+ else:
+ if isinstance(imgs, list):
+ return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
+ else:
+ return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
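+
+
+# Minimal round-trip sketch (file name is a placeholder): img2tensor/tensor2img
+# convert between BGR numpy images and RGB float tensors.
+#     >>> img = cv2.imread('example.png').astype(np.float32) / 255.   # HWC, BGR, [0, 1]
+#     >>> t = img2tensor(img, bgr2rgb=True, float32=True)             # CHW, RGB, float32
+#     >>> img_back = tensor2img(t, rgb2bgr=True, min_max=(0, 1))      # HWC, BGR, uint8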
diff --git a/StableSR/basicsr/utils/lmdb_util.py b/StableSR/basicsr/utils/lmdb_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2b45ce01d5e32ddbf8354d71fd1c8678bede822
--- /dev/null
+++ b/StableSR/basicsr/utils/lmdb_util.py
@@ -0,0 +1,199 @@
+import cv2
+import lmdb
+import sys
+from multiprocessing import Pool
+from os import path as osp
+from tqdm import tqdm
+
+
+def make_lmdb_from_imgs(data_path,
+ lmdb_path,
+ img_path_list,
+ keys,
+ batch=5000,
+ compress_level=1,
+ multiprocessing_read=False,
+ n_thread=40,
+ map_size=None):
+ """Make lmdb from images.
+
+ Contents of lmdb. The file structure is:
+
+ ::
+
+ example.lmdb
+ ├── data.mdb
+ ├── lock.mdb
+ ├── meta_info.txt
+
+ The data.mdb and lock.mdb are standard lmdb files and you can refer to
+ https://lmdb.readthedocs.io/en/release/ for more details.
+
+ The meta_info.txt is a specified txt file to record the meta information
+ of our datasets. It will be automatically created when preparing
+ datasets by our provided dataset tools.
+ Each line in the txt file records 1)image name (with extension),
+ 2)image shape, and 3)compression level, separated by a white space.
+
+ For example, the meta information could be:
+ `000_00000000.png (720,1280,3) 1`, which means:
+ 1) image name (with extension): 000_00000000.png;
+ 2) image shape: (720,1280,3);
+ 3) compression level: 1
+
+ We use the image name without extension as the lmdb key.
+
+ If `multiprocessing_read` is True, it will read all the images to memory
+ using multiprocessing. Thus, your server needs to have enough memory.
+
+ Args:
+ data_path (str): Data path for reading images.
+ lmdb_path (str): Lmdb save path.
+ img_path_list (list[str]): Image path list.
+ keys (list[str]): Used for lmdb keys.
+ batch (int): After processing batch images, lmdb commits.
+ Default: 5000.
+ compress_level (int): Compress level when encoding images. Default: 1.
+ multiprocessing_read (bool): Whether use multiprocessing to read all
+ the images to memory. Default: False.
+ n_thread (int): For multiprocessing.
+ map_size (int | None): Map size for lmdb env. If None, use the
+ estimated size from images. Default: None
+ """
+
+ assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
+ f'but got {len(img_path_list)} and {len(keys)}')
+ print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
+ print(f'Total images: {len(img_path_list)}')
+ if not lmdb_path.endswith('.lmdb'):
+ raise ValueError("lmdb_path must end with '.lmdb'.")
+ if osp.exists(lmdb_path):
+ print(f'Folder {lmdb_path} already exists. Exit.')
+ sys.exit(1)
+
+ if multiprocessing_read:
+ # read all the images to memory (multiprocessing)
+ dataset = {} # use dict to keep the order for multiprocessing
+ shapes = {}
+ print(f'Read images with multiprocessing, #thread: {n_thread} ...')
+ pbar = tqdm(total=len(img_path_list), unit='image')
+
+ def callback(arg):
+ """get the image data and update pbar."""
+ key, dataset[key], shapes[key] = arg
+ pbar.update(1)
+ pbar.set_description(f'Read {key}')
+
+ pool = Pool(n_thread)
+ for path, key in zip(img_path_list, keys):
+ pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
+ pool.close()
+ pool.join()
+ pbar.close()
+ print(f'Finish reading {len(img_path_list)} images.')
+
+ # create lmdb environment
+ if map_size is None:
+ # obtain data size for one image
+ img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
+ _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
+ data_size_per_img = img_byte.nbytes
+ print('Data size per image is: ', data_size_per_img)
+ data_size = data_size_per_img * len(img_path_list)
+ map_size = data_size * 10
+
+ env = lmdb.open(lmdb_path, map_size=map_size)
+
+ # write data to lmdb
+ pbar = tqdm(total=len(img_path_list), unit='chunk')
+ txn = env.begin(write=True)
+ txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
+ for idx, (path, key) in enumerate(zip(img_path_list, keys)):
+ pbar.update(1)
+ pbar.set_description(f'Write {key}')
+ key_byte = key.encode('ascii')
+ if multiprocessing_read:
+ img_byte = dataset[key]
+ h, w, c = shapes[key]
+ else:
+ _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
+ h, w, c = img_shape
+
+ txn.put(key_byte, img_byte)
+ # write meta information
+ txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
+ if idx % batch == 0:
+ txn.commit()
+ txn = env.begin(write=True)
+ pbar.close()
+ txn.commit()
+ env.close()
+ txt_file.close()
+ print('\nFinish writing lmdb.')
+
+
+def read_img_worker(path, key, compress_level):
+ """Read image worker.
+
+ Args:
+ path (str): Image path.
+ key (str): Image key.
+ compress_level (int): Compress level when encoding images.
+
+ Returns:
+ str: Image key.
+ byte: Image byte.
+ tuple[int]: Image shape.
+ """
+
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+ if img.ndim == 2:
+ h, w = img.shape
+ c = 1
+ else:
+ h, w, c = img.shape
+ _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
+ return (key, img_byte, (h, w, c))
+
+
+class LmdbMaker():
+ """LMDB Maker.
+
+ Args:
+ lmdb_path (str): Lmdb save path.
+ map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
+ batch (int): After processing batch images, lmdb commits.
+ Default: 5000.
+ compress_level (int): Compress level when encoding images. Default: 1.
+ """
+
+ def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
+ if not lmdb_path.endswith('.lmdb'):
+ raise ValueError("lmdb_path must end with '.lmdb'.")
+ if osp.exists(lmdb_path):
+ print(f'Folder {lmdb_path} already exists. Exit.')
+ sys.exit(1)
+
+ self.lmdb_path = lmdb_path
+ self.batch = batch
+ self.compress_level = compress_level
+ self.env = lmdb.open(lmdb_path, map_size=map_size)
+ self.txn = self.env.begin(write=True)
+ self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
+ self.counter = 0
+
+ def put(self, img_byte, key, img_shape):
+ self.counter += 1
+ key_byte = key.encode('ascii')
+ self.txn.put(key_byte, img_byte)
+ # write meta information
+ h, w, c = img_shape
+ self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
+ if self.counter % self.batch == 0:
+ self.txn.commit()
+ self.txn = self.env.begin(write=True)
+
+ def close(self):
+ self.txn.commit()
+ self.env.close()
+ self.txt_file.close()
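+
+
+# Minimal usage sketch (paths and keys are placeholders). LmdbMaker is the incremental
+# alternative to make_lmdb_from_imgs when image bytes are produced on the fly:
+#     >>> maker = LmdbMaker('datasets/example.lmdb')
+#     >>> key, img_byte, shape = read_img_worker('datasets/example/0001.png', '0001', 1)
+#     >>> maker.put(img_byte, key, shape)
+#     >>> maker.close()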
diff --git a/StableSR/basicsr/utils/logger.py b/StableSR/basicsr/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..73553dc664781a061737e94880ea1c6788c09043
--- /dev/null
+++ b/StableSR/basicsr/utils/logger.py
@@ -0,0 +1,213 @@
+import datetime
+import logging
+import time
+
+from .dist_util import get_dist_info, master_only
+
+initialized_logger = {}
+
+
+class AvgTimer():
+
+ def __init__(self, window=200):
+ self.window = window # average window
+ self.current_time = 0
+ self.total_time = 0
+ self.count = 0
+ self.avg_time = 0
+ self.start()
+
+ def start(self):
+ self.start_time = self.tic = time.time()
+
+ def record(self):
+ self.count += 1
+ self.toc = time.time()
+ self.current_time = self.toc - self.tic
+ self.total_time += self.current_time
+ # calculate average time
+ self.avg_time = self.total_time / self.count
+
+ # reset
+ if self.count > self.window:
+ self.count = 0
+ self.total_time = 0
+
+ self.tic = time.time()
+
+ def get_current_time(self):
+ return self.current_time
+
+ def get_avg_time(self):
+ return self.avg_time
+
+
+class MessageLogger():
+ """Message logger for printing.
+
+ Args:
+ opt (dict): Config. It contains the following keys:
+ name (str): Exp name.
+ logger (dict): Contains 'print_freq' (int) for logger interval.
+ train (dict): Contains 'total_iter' (int) for total iters.
+ use_tb_logger (bool): Use tensorboard logger.
+ start_iter (int): Start iter. Default: 1.
+ tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
+ """
+
+ def __init__(self, opt, start_iter=1, tb_logger=None):
+ self.exp_name = opt['name']
+ self.interval = opt['logger']['print_freq']
+ self.start_iter = start_iter
+ self.max_iters = opt['train']['total_iter']
+ self.use_tb_logger = opt['logger']['use_tb_logger']
+ self.tb_logger = tb_logger
+ self.start_time = time.time()
+ self.logger = get_root_logger()
+
+ def reset_start_time(self):
+ self.start_time = time.time()
+
+ @master_only
+ def __call__(self, log_vars):
+ """Format logging message.
+
+ Args:
+ log_vars (dict): It contains the following keys:
+ epoch (int): Epoch number.
+ iter (int): Current iter.
+ lrs (list): List for learning rates.
+
+ time (float): Iter time.
+ data_time (float): Data time for each iter.
+ """
+ # epoch, iter, learning rates
+ epoch = log_vars.pop('epoch')
+ current_iter = log_vars.pop('iter')
+ lrs = log_vars.pop('lrs')
+
+ message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(')
+ for v in lrs:
+ message += f'{v:.3e},'
+ message += ')] '
+
+ # time and estimated time
+ if 'time' in log_vars.keys():
+ iter_time = log_vars.pop('time')
+ data_time = log_vars.pop('data_time')
+
+ total_time = time.time() - self.start_time
+ time_sec_avg = total_time / (current_iter - self.start_iter + 1)
+ eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
+ eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
+ message += f'[eta: {eta_str}, '
+ message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
+
+ # other items, especially losses
+ for k, v in log_vars.items():
+ message += f'{k}: {v:.4e} '
+ # tensorboard logger
+ if self.use_tb_logger and 'debug' not in self.exp_name:
+ if k.startswith('l_'):
+ self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
+ else:
+ self.tb_logger.add_scalar(k, v, current_iter)
+ self.logger.info(message)
+
+
+@master_only
+def init_tb_logger(log_dir):
+ from torch.utils.tensorboard import SummaryWriter
+ tb_logger = SummaryWriter(log_dir=log_dir)
+ return tb_logger
+
+
+@master_only
+def init_wandb_logger(opt):
+ """We now only use wandb to sync tensorboard log."""
+ import wandb
+ logger = get_root_logger()
+
+ project = opt['logger']['wandb']['project']
+ resume_id = opt['logger']['wandb'].get('resume_id')
+ if resume_id:
+ wandb_id = resume_id
+ resume = 'allow'
+ logger.warning(f'Resume wandb logger with id={wandb_id}.')
+ else:
+ wandb_id = wandb.util.generate_id()
+ resume = 'never'
+
+ wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)
+
+ logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
+
+
+def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
+ """Get the root logger.
+
+ The logger will be initialized if it has not been initialized. By default a
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
+ also be added.
+
+ Args:
+ logger_name (str): root logger name. Default: 'basicsr'.
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the root logger.
+ log_level (int): The root logger level. Note that only the process of
+ rank 0 is affected, while other processes will set the level to
+ "Error" and be silent most of the time.
+
+ Returns:
+ logging.Logger: The root logger.
+ """
+ logger = logging.getLogger(logger_name)
+ # if the logger has been initialized, just return it
+ if logger_name in initialized_logger:
+ return logger
+
+ format_str = '%(asctime)s %(levelname)s: %(message)s'
+ stream_handler = logging.StreamHandler()
+ stream_handler.setFormatter(logging.Formatter(format_str))
+ logger.addHandler(stream_handler)
+ logger.propagate = False
+ rank, _ = get_dist_info()
+ if rank != 0:
+ logger.setLevel('ERROR')
+ elif log_file is not None:
+ logger.setLevel(log_level)
+ # add file handler
+ file_handler = logging.FileHandler(log_file, 'w')
+ file_handler.setFormatter(logging.Formatter(format_str))
+ file_handler.setLevel(log_level)
+ logger.addHandler(file_handler)
+ initialized_logger[logger_name] = True
+ return logger
+
+
+def get_env_info():
+ """Get environment information.
+
+ Currently, it only logs the software versions.
+ """
+ import torch
+ import torchvision
+
+ from basicsr.version import __version__
+ msg = r"""
+ ____ _ _____ ____
+ / __ ) ____ _ _____ (_)_____/ ___/ / __ \
+ / __ |/ __ `// ___// // ___/\__ \ / /_/ /
+ / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
+ /_____/ \__,_//____//_/ \___//____//_/ |_|
+ ______ __ __ __ __
+ / ____/____ ____ ____/ / / / __ __ _____ / /__ / /
+ / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
+ / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
+ \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
+ """
+ msg += ('\nVersion Information: '
+ f'\n\tBasicSR: {__version__}'
+ f'\n\tPyTorch: {torch.__version__}'
+ f'\n\tTorchVision: {torchvision.__version__}')
+ return msg
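+
+
+# Minimal usage sketch (the log path is a placeholder):
+#     >>> logger = get_root_logger(log_level=logging.INFO, log_file='experiments/train.log')
+#     >>> logger.info(get_env_info())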
diff --git a/StableSR/basicsr/utils/matlab_functions.py b/StableSR/basicsr/utils/matlab_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..a201f79aaf030cdba710dd97c28af1b29a93ed2a
--- /dev/null
+++ b/StableSR/basicsr/utils/matlab_functions.py
@@ -0,0 +1,178 @@
+import math
+import numpy as np
+import torch
+
+
+def cubic(x):
+ """cubic function used for calculate_weights_indices."""
+ absx = torch.abs(x)
+ absx2 = absx**2
+ absx3 = absx**3
+ return (1.5 * absx3 - 2.5 * absx2 + 1) * (
+ (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) *
+ (absx <= 2)).type_as(absx))
+
+
+def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+ """Calculate weights and indices, used for imresize function.
+
+ Args:
+ in_length (int): Input length.
+ out_length (int): Output length.
+ scale (float): Scale factor.
+ kernel (str): Kernel type. Currently only 'cubic' is supported.
+ kernel_width (int): Kernel width.
+ antialiasing (bool): Whether to apply anti-aliasing when downsampling.
+ """
+
+ if (scale < 1) and antialiasing:
+ # Use a modified kernel (larger kernel width) to simultaneously
+ # interpolate and antialias
+ kernel_width = kernel_width / scale
+
+ # Output-space coordinates
+ x = torch.linspace(1, out_length, out_length)
+
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
+ # in output space maps to 0.5 in input space, and 0.5 + scale in output
+ # space maps to 1.5 in input space.
+ u = x / scale + 0.5 * (1 - 1 / scale)
+
+ # What is the left-most pixel that can be involved in the computation?
+ left = torch.floor(u - kernel_width / 2)
+
+ # What is the maximum number of pixels that can be involved in the
+ # computation? Note: it's OK to use an extra pixel here; if the
+ # corresponding weights are all zero, it will be eliminated at the end
+ # of this function.
+ p = math.ceil(kernel_width) + 2
+
+ # The indices of the input pixels involved in computing the k-th output
+ # pixel are in row k of the indices matrix.
+ indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
+ out_length, p)
+
+ # The weights used to compute the k-th output pixel are in row k of the
+ # weights matrix.
+ distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
+
+ # apply cubic kernel
+ if (scale < 1) and antialiasing:
+ weights = scale * cubic(distance_to_center * scale)
+ else:
+ weights = cubic(distance_to_center)
+
+ # Normalize the weights matrix so that each row sums to 1.
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
+ weights = weights / weights_sum.expand(out_length, p)
+
+ # If a column in weights is all zero, get rid of it. only consider the
+ # first and last column.
+ weights_zero_tmp = torch.sum((weights == 0), 0)
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 1, p - 2)
+ weights = weights.narrow(1, 1, p - 2)
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 0, p - 2)
+ weights = weights.narrow(1, 0, p - 2)
+ weights = weights.contiguous()
+ indices = indices.contiguous()
+ sym_len_s = -indices.min() + 1
+ sym_len_e = indices.max() - in_length
+ indices = indices + sym_len_s - 1
+ return weights, indices, int(sym_len_s), int(sym_len_e)
+
+
+@torch.no_grad()
+def imresize(img, scale, antialiasing=True):
+ """imresize function same as MATLAB.
+
+ It now only supports bicubic.
+ The same scale applies for both height and width.
+
+ Args:
+ img (Tensor | Numpy array):
+ Tensor: Input image with shape (c, h, w), [0, 1] range.
+ Numpy: Input image with shape (h, w, c), [0, 1] range.
+ scale (float): Scale factor. The same scale applies for both height
+ and width.
+ antialiasing (bool): Whether to apply anti-aliasing when downsampling.
+ Default: True.
+
+ Returns:
+ Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
+ """
+ squeeze_flag = False
+ if type(img).__module__ == np.__name__: # numpy type
+ numpy_type = True
+ if img.ndim == 2:
+ img = img[:, :, None]
+ squeeze_flag = True
+ img = torch.from_numpy(img.transpose(2, 0, 1)).float()
+ else:
+ numpy_type = False
+ if img.ndim == 2:
+ img = img.unsqueeze(0)
+ squeeze_flag = True
+
+ in_c, in_h, in_w = img.size()
+ out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # get weights and indices
+ weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
+ antialiasing)
+ weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
+ antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
+ img_aug.narrow(1, sym_len_hs, in_h).copy_(img)
+
+ sym_patch = img[:, :sym_len_hs, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
+
+ sym_patch = img[:, -sym_len_he:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(in_c, out_h, in_w)
+ kernel_width = weights_h.size(1)
+ for i in range(out_h):
+ idx = int(indices_h[i][0])
+ for j in range(in_c):
+ out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
+ out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
+
+ sym_patch = out_1[:, :, :sym_len_ws]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, :, -sym_len_we:]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(in_c, out_h, out_w)
+ kernel_width = weights_w.size(1)
+ for i in range(out_w):
+ idx = int(indices_w[i][0])
+ for j in range(in_c):
+ out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])
+
+ if squeeze_flag:
+ out_2 = out_2.squeeze(0)
+ if numpy_type:
+ out_2 = out_2.numpy()
+ if not squeeze_flag:
+ out_2 = out_2.transpose(1, 2, 0)
+
+ return out_2
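+
+# A minimal usage sketch (illustration only, not part of the upstream file). The inputs
+# below are synthetic; antialiasing defaults to True and only matters when downscaling.
+#
+#     import numpy as np
+#     img_np = np.random.rand(256, 256, 3).astype(np.float32)   # (h, w, c), [0, 1]
+#     small = imresize(img_np, 0.25)                             # -> (64, 64, 3) numpy array
+#     img_t = torch.from_numpy(img_np.transpose(2, 0, 1))       # (c, h, w) tensor
+#     big = imresize(img_t, 2.0)                                 # -> (3, 512, 512) tensor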
diff --git a/StableSR/basicsr/utils/misc.py b/StableSR/basicsr/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8d4a1403509672e85e74ac476e028cefb6dbb62
--- /dev/null
+++ b/StableSR/basicsr/utils/misc.py
@@ -0,0 +1,141 @@
+import numpy as np
+import os
+import random
+import time
+import torch
+from os import path as osp
+
+from .dist_util import master_only
+
+
+def set_random_seed(seed):
+ """Set random seeds."""
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+def get_time_str():
+ return time.strftime('%Y%m%d_%H%M%S', time.localtime())
+
+
+def mkdir_and_rename(path):
+ """mkdirs. If path exists, rename it with timestamp and create a new one.
+
+ Args:
+ path (str): Folder path.
+ """
+ if osp.exists(path):
+ new_name = path + '_archived_' + get_time_str()
+ print(f'Path already exists. Rename it to {new_name}', flush=True)
+ os.rename(path, new_name)
+ os.makedirs(path, exist_ok=True)
+
+
+@master_only
+def make_exp_dirs(opt):
+ """Make dirs for experiments."""
+ path_opt = opt['path'].copy()
+ if opt['is_train']:
+ mkdir_and_rename(path_opt.pop('experiments_root'))
+ else:
+ mkdir_and_rename(path_opt.pop('results_root'))
+ for key, path in path_opt.items():
+ if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key):
+ continue
+ else:
+ os.makedirs(path, exist_ok=True)
+
+
+def scandir(dir_path, suffix=None, recursive=False, full_path=False):
+ """Scan a directory to find the interested files.
+
+ Args:
+ dir_path (str): Path of the directory.
+ suffix (str | tuple(str), optional): File suffix that we are
+ interested in. Default: None.
+ recursive (bool, optional): If set to True, recursively scan the
+ directory. Default: False.
+ full_path (bool, optional): If set to True, include the dir_path.
+ Default: False.
+
+ Returns:
+ A generator for all the matched files, as relative paths by default.
+ """
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('"suffix" must be a string or tuple of strings')
+
+ root = dir_path
+
+ def _scandir(dir_path, suffix, recursive):
+ for entry in os.scandir(dir_path):
+ if not entry.name.startswith('.') and entry.is_file():
+ if full_path:
+ return_path = entry.path
+ else:
+ return_path = osp.relpath(entry.path, root)
+
+ if suffix is None:
+ yield return_path
+ elif return_path.endswith(suffix):
+ yield return_path
+ else:
+ if recursive:
+ yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
+ else:
+ continue
+
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
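+
+# A minimal usage sketch (illustration only; the folder name is hypothetical):
+#
+#     png_paths = sorted(scandir('datasets/train', suffix='.png', recursive=True))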
+
+
+def check_resume(opt, resume_iter):
+ """Check resume states and pretrain_network paths.
+
+ Args:
+ opt (dict): Options.
+ resume_iter (int): Resume iteration.
+ """
+ if opt['path']['resume_state']:
+ # get all the networks
+ networks = [key for key in opt.keys() if key.startswith('network_')]
+ flag_pretrain = False
+ for network in networks:
+ if opt['path'].get(f'pretrain_{network}') is not None:
+ flag_pretrain = True
+ if flag_pretrain:
+ print('pretrain_network path will be ignored during resuming.')
+ # set pretrained model paths
+ for network in networks:
+ name = f'pretrain_{network}'
+ basename = network.replace('network_', '')
+ if opt['path'].get('ignore_resume_networks') is None or (network
+ not in opt['path']['ignore_resume_networks']):
+ opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
+ print(f"Set {name} to {opt['path'][name]}")
+
+ # change param_key to params in resume
+ param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')]
+ for param_key in param_keys:
+ if opt['path'][param_key] == 'params_ema':
+ opt['path'][param_key] = 'params'
+ print(f'Set {param_key} to params')
+
+
+def sizeof_fmt(size, suffix='B'):
+ """Get human readable file size.
+
+ Args:
+ size (int): File size.
+ suffix (str): Suffix. Default: 'B'.
+
+ Return:
+ str: Formatted file size.
+ """
+ for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
+ if abs(size) < 1024.0:
+ return f'{size:3.1f} {unit}{suffix}'
+ size /= 1024.0
+ return f'{size:3.1f} Y{suffix}'
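+
+# Quick worked examples (illustration only):
+#
+#     sizeof_fmt(1536)         # -> '1.5 KB'
+#     sizeof_fmt(3 * 1024**3)  # -> '3.0 GB'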
diff --git a/StableSR/basicsr/utils/options.py b/StableSR/basicsr/utils/options.py
new file mode 100644
index 0000000000000000000000000000000000000000..3afd79c4f3e73f44f36503288c3959125ac3df34
--- /dev/null
+++ b/StableSR/basicsr/utils/options.py
@@ -0,0 +1,210 @@
+import argparse
+import os
+import random
+import torch
+import yaml
+from collections import OrderedDict
+from os import path as osp
+
+from basicsr.utils import set_random_seed
+from basicsr.utils.dist_util import get_dist_info, init_dist, master_only
+
+
+def ordered_yaml():
+ """Support OrderedDict for yaml.
+
+ Returns:
+ tuple: yaml Loader and Dumper.
+ """
+ try:
+ from yaml import CDumper as Dumper
+ from yaml import CLoader as Loader
+ except ImportError:
+ from yaml import Dumper, Loader
+
+ _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
+
+ def dict_representer(dumper, data):
+ return dumper.represent_dict(data.items())
+
+ def dict_constructor(loader, node):
+ return OrderedDict(loader.construct_pairs(node))
+
+ Dumper.add_representer(OrderedDict, dict_representer)
+ Loader.add_constructor(_mapping_tag, dict_constructor)
+ return Loader, Dumper
+
+
+def yaml_load(f):
+ """Load yaml file or string.
+
+ Args:
+ f (str): File path or a python string.
+
+ Returns:
+ dict: Loaded dict.
+ """
+ if os.path.isfile(f):
+ with open(f, 'r') as f:
+ return yaml.load(f, Loader=ordered_yaml()[0])
+ else:
+ return yaml.load(f, Loader=ordered_yaml()[0])
+
+
+def dict2str(opt, indent_level=1):
+ """dict to string for printing options.
+
+ Args:
+ opt (dict): Option dict.
+ indent_level (int): Indent level. Default: 1.
+
+ Return:
+ (str): Option string for printing.
+ """
+ msg = '\n'
+ for k, v in opt.items():
+ if isinstance(v, dict):
+ msg += ' ' * (indent_level * 2) + k + ':['
+ msg += dict2str(v, indent_level + 1)
+ msg += ' ' * (indent_level * 2) + ']\n'
+ else:
+ msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
+ return msg
+
+
+def _postprocess_yml_value(value):
+ # None
+ if value == '~' or value.lower() == 'none':
+ return None
+ # bool
+ if value.lower() == 'true':
+ return True
+ elif value.lower() == 'false':
+ return False
+ # !!float number
+ if value.startswith('!!float'):
+ return float(value.replace('!!float', ''))
+ # number
+ if value.isdigit():
+ return int(value)
+ elif value.replace('.', '', 1).isdigit() and value.count('.') < 2:
+ return float(value)
+ # list
+ if value.startswith('['):
+ return eval(value)
+ # str
+ return value
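+
+# Illustration of how _postprocess_yml_value converts --force_yml strings
+# (the sample values are hypothetical):
+#
+#     '0.999'        -> 0.999 (float)      'true'    -> True
+#     'none'         -> None               '[1, 2]'  -> [1, 2] (via eval)
+#     '!!float 1e-4' -> 0.0001             'auto'    -> 'auto' (kept as str)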
+
+
+def parse_options(root_path, is_train=True):
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
+ parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
+ parser.add_argument('--auto_resume', action='store_true')
+ parser.add_argument('--debug', action='store_true')
+ parser.add_argument('--local_rank', type=int, default=0)
+ parser.add_argument(
+ '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
+ args = parser.parse_args()
+
+ # parse yml to dict
+ opt = yaml_load(args.opt)
+
+ # distributed settings
+ if args.launcher == 'none':
+ opt['dist'] = False
+ print('Disable distributed.', flush=True)
+ else:
+ opt['dist'] = True
+ if args.launcher == 'slurm' and 'dist_params' in opt:
+ init_dist(args.launcher, **opt['dist_params'])
+ else:
+ init_dist(args.launcher)
+ opt['rank'], opt['world_size'] = get_dist_info()
+
+ # random seed
+ seed = opt.get('manual_seed')
+ if seed is None:
+ seed = random.randint(1, 10000)
+ opt['manual_seed'] = seed
+ set_random_seed(seed + opt['rank'])
+
+ # force to update yml options
+ if args.force_yml is not None:
+ for entry in args.force_yml:
+ # now do not support creating new keys
+ keys, value = entry.split('=')
+ keys, value = keys.strip(), value.strip()
+ value = _postprocess_yml_value(value)
+ eval_str = 'opt'
+ for key in keys.split(':'):
+ eval_str += f'["{key}"]'
+ eval_str += '=value'
+ # using exec function
+ exec(eval_str)
+
+ opt['auto_resume'] = args.auto_resume
+ opt['is_train'] = is_train
+
+ # debug setting
+ if args.debug and not opt['name'].startswith('debug'):
+ opt['name'] = 'debug_' + opt['name']
+
+ if opt['num_gpu'] == 'auto':
+ opt['num_gpu'] = torch.cuda.device_count()
+
+ # datasets
+ for phase, dataset in opt['datasets'].items():
+ # for multiple datasets, e.g., val_1, val_2; test_1, test_2
+ phase = phase.split('_')[0]
+ dataset['phase'] = phase
+ if 'scale' in opt:
+ dataset['scale'] = opt['scale']
+ if dataset.get('dataroot_gt') is not None:
+ dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
+ if dataset.get('dataroot_lq') is not None:
+ dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
+
+ # paths
+ for key, val in opt['path'].items():
+ if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
+ opt['path'][key] = osp.expanduser(val)
+
+ if is_train:
+ experiments_root = osp.join(root_path, 'experiments', opt['name'])
+ opt['path']['experiments_root'] = experiments_root
+ opt['path']['models'] = osp.join(experiments_root, 'models')
+ opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
+ opt['path']['log'] = experiments_root
+ opt['path']['visualization'] = osp.join(experiments_root, 'visualization')
+
+ # change some options for debug mode
+ if 'debug' in opt['name']:
+ if 'val' in opt:
+ opt['val']['val_freq'] = 8
+ opt['logger']['print_freq'] = 1
+ opt['logger']['save_checkpoint_freq'] = 8
+ else: # test
+ results_root = osp.join(root_path, 'results', opt['name'])
+ opt['path']['results_root'] = results_root
+ opt['path']['log'] = results_root
+ opt['path']['visualization'] = osp.join(results_root, 'visualization')
+
+ return opt, args
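+
+# Typical entry-point usage (a sketch; the root-path derivation is an assumption and
+# mirrors how basicsr training scripts usually call this function):
+#
+#     root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
+#     opt, args = parse_options(root_path, is_train=True)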
+
+
+@master_only
+def copy_opt_file(opt_file, experiments_root):
+ # copy the yml file to the experiment root
+ import sys
+ import time
+ from shutil import copyfile
+ cmd = ' '.join(sys.argv)
+ filename = osp.join(experiments_root, osp.basename(opt_file))
+ copyfile(opt_file, filename)
+
+ with open(filename, 'r+') as f:
+ lines = f.readlines()
+ lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n')
+ f.seek(0)
+ f.writelines(lines)
diff --git a/StableSR/basicsr/utils/plot_util.py b/StableSR/basicsr/utils/plot_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e6da5bc29e706da87ab83af6d5367176fe78763
--- /dev/null
+++ b/StableSR/basicsr/utils/plot_util.py
@@ -0,0 +1,83 @@
+import re
+
+
+def read_data_from_tensorboard(log_path, tag):
+ """Get raw data (steps and values) from tensorboard events.
+
+ Args:
+ log_path (str): Path to the tensorboard log.
+ tag (str): tag to be read.
+ """
+ from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
+
+ # tensorboard event
+ event_acc = EventAccumulator(log_path)
+ event_acc.Reload()
+ scalar_list = event_acc.Tags()['scalars']
+ print('tag list: ', scalar_list)
+ steps = [int(s.step) for s in event_acc.Scalars(tag)]
+ values = [s.value for s in event_acc.Scalars(tag)]
+ return steps, values
+
+
+def read_data_from_txt_2v(path, pattern, step_one=False):
+ """Read data from txt with 2 returned values (usually [step, value]).
+
+ Args:
+ path (str): path to the txt file.
+ pattern (str): re (regular expression) pattern.
+ step_one (bool): add 1 to steps. Default: False.
+ """
+ with open(path) as f:
+ lines = f.readlines()
+ lines = [line.strip() for line in lines]
+ steps = []
+ values = []
+
+ pattern = re.compile(pattern)
+ for line in lines:
+ match = pattern.match(line)
+ if match:
+ steps.append(int(match.group(1)))
+ values.append(float(match.group(2)))
+ if step_one:
+ steps = [v + 1 for v in steps]
+ return steps, values
+
+
+def read_data_from_txt_1v(path, pattern):
+ """Read data from txt with 1 returned values.
+
+ Args:
+ path (str): path to the txt file.
+ pattern (str): re (regular expression) pattern.
+ """
+ with open(path) as f:
+ lines = f.readlines()
+ lines = [line.strip() for line in lines]
+ data = []
+
+ pattern = re.compile(pattern)
+ for line in lines:
+ match = pattern.match(line)
+ if match:
+ data.append(float(match.group(1)))
+ return data
+
+
+def smooth_data(values, smooth_weight):
+ """ Smooth data using 1st-order IIR low-pass filter (what tensorflow does).
+
+ Reference: https://github.com/tensorflow/tensorboard/blob/f801ebf1f9fbfe2baee1ddd65714d0bccc640fb1/tensorboard/plugins/scalar/vz_line_chart/vz-line-chart.ts#L704 # noqa: E501
+
+ Args:
+ values (list): A list of values to be smoothed.
+ smooth_weight (float): Smooth weight.
+ """
+ values_sm = []
+ last_sm_value = values[0]
+ for value in values:
+ value_sm = last_sm_value * smooth_weight + (1 - smooth_weight) * value
+ values_sm.append(value_sm)
+ last_sm_value = value_sm
+ return values_sm
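+
+# A short worked example (illustration only). With smooth_weight = 0.5 the filter is
+# y[t] = 0.5 * y[t-1] + 0.5 * x[t], seeded with y[-1] = x[0]:
+#
+#     smooth_data([0.0, 1.0, 1.0], 0.5)  # -> [0.0, 0.5, 0.75]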
diff --git a/StableSR/basicsr/utils/realesrgan_utils.py b/StableSR/basicsr/utils/realesrgan_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff934e5150b4aa568a51ab9614a2057b011a6014
--- /dev/null
+++ b/StableSR/basicsr/utils/realesrgan_utils.py
@@ -0,0 +1,293 @@
+import cv2
+import math
+import numpy as np
+import os
+import queue
+import threading
+import torch
+from basicsr.utils.download_util import load_file_from_url
+from torch.nn import functional as F
+
+# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+
+class RealESRGANer():
+ """A helper class for upsampling images with RealESRGAN.
+
+ Args:
+ scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
+ model_path (str): The path to the pretrained model. It can also be a URL (the model will first be downloaded automatically).
+ model (nn.Module): The defined network. Default: None.
+ tile (int): Since very large images can cause out-of-GPU-memory errors, this tile option first crops the
+ input image into tiles and processes each of them separately. Finally, they are merged into one image.
+ 0 means tiling is not used. Default: 0.
+ tile_pad (int): The pad size for each tile, used to remove border artifacts. Default: 10.
+ pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
+ half (bool): Whether to use half precision during inference. Default: False.
+ """
+
+ def __init__(self,
+ scale,
+ model_path,
+ model=None,
+ tile=0,
+ tile_pad=10,
+ pre_pad=10,
+ half=False,
+ device=None,
+ gpu_id=None):
+ self.scale = scale
+ self.tile_size = tile
+ self.tile_pad = tile_pad
+ self.pre_pad = pre_pad
+ self.mod_scale = None
+ self.half = half
+
+ # initialize model
+ if gpu_id:
+ self.device = torch.device(
+ f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
+ else:
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
+ # if the model_path starts with https, it will first download the model to the folder: weights/realesrgan
+ if model_path.startswith('https://'):
+ model_path = load_file_from_url(
+ url=model_path, model_dir=os.path.join('weights/realesrgan'), progress=True, file_name=None)
+ loadnet = torch.load(model_path, map_location=torch.device('cpu'))
+ # prefer to use params_ema
+ if 'params_ema' in loadnet:
+ keyname = 'params_ema'
+ else:
+ keyname = 'params'
+ model.load_state_dict(loadnet[keyname], strict=True)
+ model.eval()
+ self.model = model.to(self.device)
+ if self.half:
+ self.model = self.model.half()
+
+ def pre_process(self, img):
+ """Pre-process, such as pre-pad and mod pad, so that the images can be divisible
+ """
+ img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
+ self.img = img.unsqueeze(0).to(self.device)
+ if self.half:
+ self.img = self.img.half()
+
+ # pre_pad
+ if self.pre_pad != 0:
+ self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
+ # mod pad for divisible borders
+ if self.scale == 2:
+ self.mod_scale = 2
+ elif self.scale == 1:
+ self.mod_scale = 4
+ if self.mod_scale is not None:
+ self.mod_pad_h, self.mod_pad_w = 0, 0
+ _, _, h, w = self.img.size()
+ if (h % self.mod_scale != 0):
+ self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
+ if (w % self.mod_scale != 0):
+ self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
+ self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
+
+ def process(self):
+ # model inference
+ self.output = self.model(self.img)
+
+ def tile_process(self):
+ """It will first crop input images to tiles, and then process each tile.
+ Finally, all the processed tiles are merged into one images.
+
+ Modified from: https://github.com/ata4/esrgan-launcher
+ """
+ batch, channel, height, width = self.img.shape
+ output_height = height * self.scale
+ output_width = width * self.scale
+ output_shape = (batch, channel, output_height, output_width)
+
+ # start with black image
+ self.output = self.img.new_zeros(output_shape)
+ tiles_x = math.ceil(width / self.tile_size)
+ tiles_y = math.ceil(height / self.tile_size)
+
+ # loop over all tiles
+ for y in range(tiles_y):
+ for x in range(tiles_x):
+ # extract tile from input image
+ ofs_x = x * self.tile_size
+ ofs_y = y * self.tile_size
+ # input tile area on total image
+ input_start_x = ofs_x
+ input_end_x = min(ofs_x + self.tile_size, width)
+ input_start_y = ofs_y
+ input_end_y = min(ofs_y + self.tile_size, height)
+
+ # input tile area on total image with padding
+ input_start_x_pad = max(input_start_x - self.tile_pad, 0)
+ input_end_x_pad = min(input_end_x + self.tile_pad, width)
+ input_start_y_pad = max(input_start_y - self.tile_pad, 0)
+ input_end_y_pad = min(input_end_y + self.tile_pad, height)
+
+ # input tile dimensions
+ input_tile_width = input_end_x - input_start_x
+ input_tile_height = input_end_y - input_start_y
+ tile_idx = y * tiles_x + x + 1
+ input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
+
+ # upscale tile
+ try:
+ with torch.no_grad():
+ output_tile = self.model(input_tile)
+ except RuntimeError as error:
+ print('Error', error)
+ # print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
+
+ # output tile area on total image
+ output_start_x = input_start_x * self.scale
+ output_end_x = input_end_x * self.scale
+ output_start_y = input_start_y * self.scale
+ output_end_y = input_end_y * self.scale
+
+ # output tile area without padding
+ output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
+ output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
+ output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
+ output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
+
+ # put tile into output image
+ self.output[:, :, output_start_y:output_end_y,
+ output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
+ output_start_x_tile:output_end_x_tile]
+
+ def post_process(self):
+ # remove extra pad
+ if self.mod_scale is not None:
+ _, _, h, w = self.output.size()
+ self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
+ # remove prepad
+ if self.pre_pad != 0:
+ _, _, h, w = self.output.size()
+ self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
+ return self.output
+
+ @torch.no_grad()
+ def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
+ h_input, w_input = img.shape[0:2]
+ # img: numpy
+ img = img.astype(np.float32)
+ if np.max(img) > 256: # 16-bit image
+ max_range = 65535
+ print('\tInput is a 16-bit image')
+ else:
+ max_range = 255
+ img = img / max_range
+ if len(img.shape) == 2: # gray image
+ img_mode = 'L'
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+ elif img.shape[2] == 4: # RGBA image with alpha channel
+ img_mode = 'RGBA'
+ alpha = img[:, :, 3]
+ img = img[:, :, 0:3]
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ if alpha_upsampler == 'realesrgan':
+ alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
+ else:
+ img_mode = 'RGB'
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+ # ------------------- process image (without the alpha channel) ------------------- #
+ self.pre_process(img)
+ if self.tile_size > 0:
+ self.tile_process()
+ else:
+ self.process()
+ output_img = self.post_process()
+ output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+ output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
+ if img_mode == 'L':
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
+
+ # ------------------- process the alpha channel if necessary ------------------- #
+ if img_mode == 'RGBA':
+ if alpha_upsampler == 'realesrgan':
+ self.pre_process(alpha)
+ if self.tile_size > 0:
+ self.tile_process()
+ else:
+ self.process()
+ output_alpha = self.post_process()
+ output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+ output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
+ output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
+ else: # use the cv2 resize for alpha channel
+ h, w = alpha.shape[0:2]
+ output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
+
+ # merge the alpha channel
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
+ output_img[:, :, 3] = output_alpha
+
+ # ------------------------------ return ------------------------------ #
+ if max_range == 65535: # 16-bit image
+ output = (output_img * 65535.0).round().astype(np.uint16)
+ else:
+ output = (output_img * 255.0).round().astype(np.uint8)
+
+ if outscale is not None and outscale != float(self.scale):
+ output = cv2.resize(
+ output, (
+ int(w_input * outscale),
+ int(h_input * outscale),
+ ), interpolation=cv2.INTER_LANCZOS4)
+
+ return output, img_mode
+
+
+class PrefetchReader(threading.Thread):
+ """Prefetch images.
+
+ Args:
+ img_list (list[str]): A list of image paths to be read.
+ num_prefetch_queue (int): Number of prefetch queue.
+ """
+
+ def __init__(self, img_list, num_prefetch_queue):
+ super().__init__()
+ self.que = queue.Queue(num_prefetch_queue)
+ self.img_list = img_list
+
+ def run(self):
+ for img_path in self.img_list:
+ img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
+ self.que.put(img)
+
+ self.que.put(None)
+
+ def __next__(self):
+ next_item = self.que.get()
+ if next_item is None:
+ raise StopIteration
+ return next_item
+
+ def __iter__(self):
+ return self
+
+
+class IOConsumer(threading.Thread):
+
+ def __init__(self, opt, que, qid):
+ super().__init__()
+ self._queue = que
+ self.qid = qid
+ self.opt = opt
+
+ def run(self):
+ while True:
+ msg = self._queue.get()
+ if isinstance(msg, str) and msg == 'quit':
+ break
+
+ output = msg['output']
+ save_path = msg['save_path']
+ cv2.imwrite(save_path, output)
+ print(f'IO worker {self.qid} is done.')
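+
+# A hedged usage sketch (not part of the upstream file). The network, weight path and
+# image path below are assumptions; substitute your own architecture and checkpoint.
+#
+#     from basicsr.archs.rrdbnet_arch import RRDBNet
+#     net = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+#     upsampler = RealESRGANer(scale=4, model_path='weights/RealESRGAN_x4plus.pth',
+#                              model=net, tile=400, tile_pad=10, pre_pad=0, half=True)
+#     img = cv2.imread('input.png', cv2.IMREAD_UNCHANGED)   # BGR, uint8
+#     output, img_mode = upsampler.enhance(img, outscale=4)
+#     cv2.imwrite('output.png', output)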
diff --git a/StableSR/basicsr/utils/registry.py b/StableSR/basicsr/utils/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e72ef7ff21b94f50e6caa8948f69ca0b04bc968
--- /dev/null
+++ b/StableSR/basicsr/utils/registry.py
@@ -0,0 +1,88 @@
+# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
+
+
+class Registry():
+ """
+ The registry that provides name -> object mapping, to support third-party
+ users' custom modules.
+
+ To create a registry (e.g. a backbone registry):
+
+ .. code-block:: python
+
+ BACKBONE_REGISTRY = Registry('BACKBONE')
+
+ To register an object:
+
+ .. code-block:: python
+
+ @BACKBONE_REGISTRY.register()
+ class MyBackbone():
+ ...
+
+ Or:
+
+ .. code-block:: python
+
+ BACKBONE_REGISTRY.register(MyBackbone)
+ """
+
+ def __init__(self, name):
+ """
+ Args:
+ name (str): the name of this registry
+ """
+ self._name = name
+ self._obj_map = {}
+
+ def _do_register(self, name, obj, suffix=None):
+ if isinstance(suffix, str):
+ name = name + '_' + suffix
+
+ assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
+ f"in '{self._name}' registry!")
+ self._obj_map[name] = obj
+
+ def register(self, obj=None, suffix=None):
+ """
+ Register the given object under the name `obj.__name__`.
+ Can be used as either a decorator or not.
+ See docstring of this class for usage.
+ """
+ if obj is None:
+ # used as a decorator
+ def deco(func_or_class):
+ name = func_or_class.__name__
+ self._do_register(name, func_or_class, suffix)
+ return func_or_class
+
+ return deco
+
+ # used as a function call
+ name = obj.__name__
+ self._do_register(name, obj, suffix)
+
+ def get(self, name, suffix='basicsr'):
+ ret = self._obj_map.get(name)
+ if ret is None:
+ ret = self._obj_map.get(name + '_' + suffix)
+ print(f'Name {name} is not found, using name: {name}_{suffix} instead!')
+ if ret is None:
+ raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
+ return ret
+
+ def __contains__(self, name):
+ return name in self._obj_map
+
+ def __iter__(self):
+ return iter(self._obj_map.items())
+
+ def keys(self):
+ return self._obj_map.keys()
+
+
+DATASET_REGISTRY = Registry('dataset')
+ARCH_REGISTRY = Registry('arch')
+MODEL_REGISTRY = Registry('model')
+LOSS_REGISTRY = Registry('loss')
+METRIC_REGISTRY = Registry('metric')
diff --git a/StableSR/clip/__init__.py b/StableSR/clip/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcc5619538c0f7c782508bdbd9587259d805e0d9
--- /dev/null
+++ b/StableSR/clip/__init__.py
@@ -0,0 +1 @@
+from .clip import *
diff --git a/StableSR/clip/bpe_simple_vocab_16e6.txt.gz b/StableSR/clip/bpe_simple_vocab_16e6.txt.gz
new file mode 100644
index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113
--- /dev/null
+++ b/StableSR/clip/bpe_simple_vocab_16e6.txt.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+size 1356917
diff --git a/StableSR/clip/clip.py b/StableSR/clip/clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7a5da5e69e0a3b41383734711ccfff1923a9ef9
--- /dev/null
+++ b/StableSR/clip/clip.py
@@ -0,0 +1,245 @@
+import hashlib
+import os
+import urllib
+import warnings
+from typing import Any, Union, List
+from pkg_resources import packaging
+
+import torch
+from PIL import Image
+from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+from tqdm import tqdm
+
+from .model import build_model
+from .simple_tokenizer import SimpleTokenizer as _Tokenizer
+
+try:
+ from torchvision.transforms import InterpolationMode
+ BICUBIC = InterpolationMode.BICUBIC
+except ImportError:
+ BICUBIC = Image.BICUBIC
+
+
+if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
+ warnings.warn("PyTorch version 1.7.1 or higher is recommended")
+
+
+__all__ = ["available_models", "load", "tokenize"]
+_tokenizer = _Tokenizer()
+
+_MODELS = {
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
+ "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
+ "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
+ "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
+}
+
+
+def _download(url: str, root: str):
+ os.makedirs(root, exist_ok=True)
+ filename = os.path.basename(url)
+
+ expected_sha256 = url.split("/")[-2]
+ download_target = os.path.join(root, filename)
+
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+ if os.path.isfile(download_target):
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
+ return download_target
+ else:
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
+ while True:
+ buffer = source.read(8192)
+ if not buffer:
+ break
+
+ output.write(buffer)
+ loop.update(len(buffer))
+
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
+
+ return download_target
+
+
+def _convert_image_to_rgb(image):
+ return image.convert("RGB")
+
+
+def _transform(n_px):
+ return Compose([
+ Resize(n_px, interpolation=BICUBIC),
+ CenterCrop(n_px),
+ _convert_image_to_rgb,
+ ToTensor(),
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+
+def available_models() -> List[str]:
+ """Returns the names of available CLIP models"""
+ return list(_MODELS.keys())
+
+
+def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
+ """Load a CLIP model
+
+ Parameters
+ ----------
+ name : str
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
+
+ device : Union[str, torch.device]
+ The device to put the loaded model
+
+ jit : bool
+ Whether to load the optimized JIT model or more hackable non-JIT model (default).
+
+ download_root: str
+ path to download the model files; by default, it uses "~/.cache/clip"
+
+ Returns
+ -------
+ model : torch.nn.Module
+ The CLIP model
+
+ preprocess : Callable[[PIL.Image], torch.Tensor]
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
+ """
+ if name in _MODELS:
+ model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
+ elif os.path.isfile(name):
+ model_path = name
+ else:
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
+
+ with open(model_path, 'rb') as opened_file:
+ try:
+ # loading JIT archive
+ model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
+ state_dict = None
+ except RuntimeError:
+ # loading saved state dict
+ if jit:
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
+ jit = False
+ state_dict = torch.load(opened_file, map_location="cpu")
+
+ if not jit:
+ model = build_model(state_dict or model.state_dict()).to(device)
+ if str(device) == "cpu":
+ model.float()
+ return model, _transform(model.visual.input_resolution)
+
+ # patch the device names
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
+
+ def _node_get(node: torch._C.Node, key: str):
+ """Gets attributes of a node which is polymorphic over return type.
+
+ From https://github.com/pytorch/pytorch/pull/82628
+ """
+ sel = node.kindOf(key)
+ return getattr(node, sel)(key)
+
+ def patch_device(module):
+ try:
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ except RuntimeError:
+ graphs = []
+
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("prim::Constant"):
+ if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
+ node.copyAttributes(device_node)
+
+ model.apply(patch_device)
+ patch_device(model.encode_image)
+ patch_device(model.encode_text)
+
+ # patch dtype to float32 on CPU
+ if str(device) == "cpu":
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
+ float_node = float_input.node()
+
+ def patch_float(module):
+ try:
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ except RuntimeError:
+ graphs = []
+
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("aten::to"):
+ inputs = list(node.inputs())
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
+ if _node_get(inputs[i].node(), "value") == 5:
+ inputs[i].node().copyAttributes(float_node)
+
+ model.apply(patch_float)
+ patch_float(model.encode_image)
+ patch_float(model.encode_text)
+
+ model.float()
+
+ return model, _transform(model.input_resolution.item())
+
+
+def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
+ """
+ Returns the tokenized representation of given input string(s)
+
+ Parameters
+ ----------
+ texts : Union[str, List[str]]
+ An input string or a list of input strings to tokenize
+
+ context_length : int
+ The context length to use; all CLIP models use 77 as the context length
+
+ truncate: bool
+ Whether to truncate the text in case its encoding is longer than the context length
+
+ Returns
+ -------
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
+ We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
+ """
+ if isinstance(texts, str):
+ texts = [texts]
+
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+ else:
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
+
+ for i, tokens in enumerate(all_tokens):
+ if len(tokens) > context_length:
+ if truncate:
+ tokens = tokens[:context_length]
+ tokens[-1] = eot_token
+ else:
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
+ result[i, :len(tokens)] = torch.tensor(tokens)
+
+ return result
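+
+# A minimal usage sketch (illustration only; the image path and prompts are arbitrary):
+#
+#     model, preprocess = load("ViT-B/32", device="cpu")
+#     image = preprocess(Image.open("example.jpg")).unsqueeze(0)
+#     text = tokenize(["a photo of a cat", "a photo of a dog"])
+#     with torch.no_grad():
+#         logits_per_image, logits_per_text = model(image, text)
+#         probs = logits_per_image.softmax(dim=-1)   # similarity over the prompts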
diff --git a/StableSR/clip/model.py b/StableSR/clip/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..232b7792eb97440642547bd462cf128df9243933
--- /dev/null
+++ b/StableSR/clip/model.py
@@ -0,0 +1,436 @@
+from collections import OrderedDict
+from typing import Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1):
+ super().__init__()
+
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu1 = nn.ReLU(inplace=True)
+
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.relu2 = nn.ReLU(inplace=True)
+
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.relu3 = nn.ReLU(inplace=True)
+
+ self.downsample = None
+ self.stride = stride
+
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+ self.downsample = nn.Sequential(OrderedDict([
+ ("-1", nn.AvgPool2d(stride)),
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
+ ("1", nn.BatchNorm2d(planes * self.expansion))
+ ]))
+
+ def forward(self, x: torch.Tensor):
+ identity = x
+
+ out = self.relu1(self.bn1(self.conv1(x)))
+ out = self.relu2(self.bn2(self.conv2(out)))
+ out = self.avgpool(out)
+ out = self.bn3(self.conv3(out))
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu3(out)
+ return out
+
+
+class AttentionPool2d(nn.Module):
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x):
+ x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
+ x, _ = F.multi_head_attention_forward(
+ query=x[:1], key=x, value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False
+ )
+ return x.squeeze(0)
+
+
+class ModifiedResNet(nn.Module):
+ """
+ A ResNet class that is similar to torchvision's but contains the following changes:
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+ - The final pooling layer is a QKV attention instead of an average pool
+ """
+
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
+ super().__init__()
+ self.output_dim = output_dim
+ self.input_resolution = input_resolution
+
+ # the 3-layer stem
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(width // 2)
+ self.relu1 = nn.ReLU(inplace=True)
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(width // 2)
+ self.relu2 = nn.ReLU(inplace=True)
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(width)
+ self.relu3 = nn.ReLU(inplace=True)
+ self.avgpool = nn.AvgPool2d(2)
+
+ # residual layers
+ self._inplanes = width # this is a *mutable* variable used during construction
+ self.layer1 = self._make_layer(width, layers[0])
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+ embed_dim = width * 32 # the ResNet feature dimension
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
+
+ def _make_layer(self, planes, blocks, stride=1):
+ layers = [Bottleneck(self._inplanes, planes, stride)]
+
+ self._inplanes = planes * Bottleneck.expansion
+ for _ in range(1, blocks):
+ layers.append(Bottleneck(self._inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ def stem(x):
+ x = self.relu1(self.bn1(self.conv1(x)))
+ x = self.relu2(self.bn2(self.conv2(x)))
+ x = self.relu3(self.bn3(self.conv3(x)))
+ x = self.avgpool(x)
+ return x
+
+ x = x.type(self.conv1.weight.dtype)
+ x = stem(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.attnpool(x)
+
+ return x
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ ret = super().forward(x.type(torch.float32))
+ return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+ super().__init__()
+
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ln_1 = LayerNorm(d_model)
+ self.mlp = nn.Sequential(OrderedDict([
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
+ ("gelu", QuickGELU()),
+ ("c_proj", nn.Linear(d_model * 4, d_model))
+ ]))
+ self.ln_2 = LayerNorm(d_model)
+ self.attn_mask = attn_mask
+
+ def attention(self, x: torch.Tensor):
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+ def forward(self, x: torch.Tensor):
+ x = x + self.attention(self.ln_1(x))
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+
+class Transformer(nn.Module):
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
+
+ def forward(self, x: torch.Tensor):
+ return self.resblocks(x)
+
+
+class VisionTransformer(nn.Module):
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.output_dim = output_dim
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+ scale = width ** -0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
+ self.ln_pre = LayerNorm(width)
+
+ self.transformer = Transformer(width, layers, heads)
+
+ self.ln_post = LayerNorm(width)
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+ def forward(self, x: torch.Tensor):
+ x = self.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
+ x = x + self.positional_embedding.to(x.dtype)
+ x = self.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ x = self.ln_post(x[:, 0, :])
+
+ if self.proj is not None:
+ x = x @ self.proj
+
+ return x
+
+
+class CLIP(nn.Module):
+ def __init__(self,
+ embed_dim: int,
+ # vision
+ image_resolution: int,
+ vision_layers: Union[Tuple[int, int, int, int], int],
+ vision_width: int,
+ vision_patch_size: int,
+ # text
+ context_length: int,
+ vocab_size: int,
+ transformer_width: int,
+ transformer_heads: int,
+ transformer_layers: int
+ ):
+ super().__init__()
+
+ self.context_length = context_length
+
+ if isinstance(vision_layers, (tuple, list)):
+ vision_heads = vision_width * 32 // 64
+ self.visual = ModifiedResNet(
+ layers=vision_layers,
+ output_dim=embed_dim,
+ heads=vision_heads,
+ input_resolution=image_resolution,
+ width=vision_width
+ )
+ else:
+ vision_heads = vision_width // 64
+ self.visual = VisionTransformer(
+ input_resolution=image_resolution,
+ patch_size=vision_patch_size,
+ width=vision_width,
+ layers=vision_layers,
+ heads=vision_heads,
+ output_dim=embed_dim
+ )
+
+ self.transformer = Transformer(
+ width=transformer_width,
+ layers=transformer_layers,
+ heads=transformer_heads,
+ attn_mask=self.build_attention_mask()
+ )
+
+ self.vocab_size = vocab_size
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
+ self.ln_final = LayerNorm(transformer_width)
+
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+ self.initialize_parameters()
+
+ def initialize_parameters(self):
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
+ nn.init.normal_(self.positional_embedding, std=0.01)
+
+ if isinstance(self.visual, ModifiedResNet):
+ if self.visual.attnpool is not None:
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
+
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
+ for name, param in resnet_block.named_parameters():
+ if name.endswith("bn3.weight"):
+ nn.init.zeros_(param)
+
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
+ attn_std = self.transformer.width ** -0.5
+ fc_std = (2 * self.transformer.width) ** -0.5
+ for block in self.transformer.resblocks:
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+ if self.text_projection is not None:
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
+
+ def build_attention_mask(self):
+ # lazily create causal attention mask, with full attention between the vision tokens
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(self.context_length, self.context_length)
+ mask.fill_(float("-inf"))
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
+
+ @property
+ def dtype(self):
+ return self.visual.conv1.weight.dtype
+
+ def encode_image(self, image):
+ return self.visual(image.type(self.dtype))
+
+ def encode_text(self, text):
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
+
+ x = x + self.positional_embedding.type(self.dtype)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x).type(self.dtype)
+
+ # x.shape = [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+
+ return x
+
+ def forward(self, image, text):
+ image_features = self.encode_image(image)
+ text_features = self.encode_text(text)
+
+ # normalized features
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
+
+ # cosine similarity as logits
+ logit_scale = self.logit_scale.exp()
+ logits_per_image = logit_scale * image_features @ text_features.t()
+ logits_per_text = logits_per_image.t()
+
+ # shape = [global_batch_size, global_batch_size]
+ return logits_per_image, logits_per_text
+
+
+def convert_weights(model: nn.Module):
+ """Convert applicable model parameters to fp16"""
+
+ def _convert_weights_to_fp16(l):
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+ l.weight.data = l.weight.data.half()
+ if l.bias is not None:
+ l.bias.data = l.bias.data.half()
+
+ if isinstance(l, nn.MultiheadAttention):
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+ tensor = getattr(l, attr)
+ if tensor is not None:
+ tensor.data = tensor.data.half()
+
+ for name in ["text_projection", "proj"]:
+ if hasattr(l, name):
+ attr = getattr(l, name)
+ if attr is not None:
+ attr.data = attr.data.half()
+
+ model.apply(_convert_weights_to_fp16)
+
+
+def build_model(state_dict: dict):
+ vit = "visual.proj" in state_dict
+
+ if vit:
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
+ image_resolution = vision_patch_size * grid_size
+ else:
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
+ vision_layers = tuple(counts)
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
+ vision_patch_size = None
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
+ image_resolution = output_width * 32
+
+ embed_dim = state_dict["text_projection"].shape[1]
+ context_length = state_dict["positional_embedding"].shape[0]
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
+ transformer_width = state_dict["ln_final.weight"].shape[0]
+ transformer_heads = transformer_width // 64
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
+
+ model = CLIP(
+ embed_dim,
+ image_resolution, vision_layers, vision_width, vision_patch_size,
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
+ )
+
+ for key in ["input_resolution", "context_length", "vocab_size"]:
+ if key in state_dict:
+ del state_dict[key]
+
+ convert_weights(model)
+ model.load_state_dict(state_dict)
+ return model.eval()
diff --git a/StableSR/clip/simple_tokenizer.py b/StableSR/clip/simple_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a66286b7d5019c6e221932a813768038f839c91
--- /dev/null
+++ b/StableSR/clip/simple_tokenizer.py
@@ -0,0 +1,132 @@
+import gzip
+import html
+import os
+from functools import lru_cache
+
+import ftfy
+import regex as re
+
+
+@lru_cache()
+def default_bpe():
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+ Returns a list of utf-8 bytes and a corresponding list of unicode strings.
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+ This is a significant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ This also avoids mapping to whitespace/control characters that the bpe code barfs on.
+ """
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8+n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+
+
+class SimpleTokenizer(object):
+ def __init__(self, bpe_path: str = default_bpe()):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+ merges = merges[1:49152-256-2+1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+ vocab = vocab + [v+'</w>' for v in vocab]
+ for merge in merges:
+ vocab.append(''.join(merge))
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token+'</w>'
+
+ while True:
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
+ new_word.append(first+second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = ' '.join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+ return bpe_tokens
+
+ def decode(self, tokens):
+ text = ''.join([self.decoder[token] for token in tokens])
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
+ return text
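+
+# Round-trip sketch (illustration only; the sentence is arbitrary):
+#
+#     tok = SimpleTokenizer()
+#     ids = tok.encode("a diagram of a cat")   # list of BPE token ids
+#     text = tok.decode(ids)                   # roughly recovers the input, lower-cased,
+#                                              # with '</w>' markers turned into spaces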
diff --git a/StableSR/cog.yaml b/StableSR/cog.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..77fdfa97e5a9f1667e70a1d16ca53cb90786ddb4
--- /dev/null
+++ b/StableSR/cog.yaml
@@ -0,0 +1,32 @@
+# Configuration for Cog ⚙️
+# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
+
+build:
+ gpu: true
+ system_packages:
+ - "libgl1-mesa-glx"
+ - "libglib2.0-0"
+ python_version: "3.11"
+ python_packages:
+ - "torch==2.0.1"
+ - "torchvision==0.15.2"
+ - "numpy==1.25.1"
+ - "opencv-python==4.8.0.74"
+ - "imageio==2.31.1"
+ - "omegaconf==2.3.0"
+ - "transformers==4.31.0"
+ - "torchmetrics==0.7.0"
+ - "open_clip_torch==2.0.2"
+ - "einops==0.6.1"
+ - "pytorch_lightning==1.7.7"
+ - "scipy==1.11.1"
+ - "scikit-image==0.21.0"
+ - "matplotlib==3.7.2"
+ - "scikit-learn==1.3.0"
+ - "kornia==0.6.12"
+ - "xformers==0.0.20"
+ - "clip @ git+https://github.com/openai/CLIP.git"
+ run:
+ - pip install git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
+ - mkdir -p /root/.cache/torch/hub/checkpoints && wget --output-document "/root/.cache/torch/hub/checkpoints/vgg16-397923af.pth" "https://download.pytorch.org/models/vgg16-397923af.pth"
+predict: "predict.py:Predictor"
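+
+# Typical invocation once Cog is installed (the CLI call below is an assumption about
+# standard Cog usage; the actual input name is defined in predict.py, not shown here):
+#   cog predict -i <input_name>=@path/to/low_res.png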
diff --git a/StableSR/configs/autoencoder/autoencoder_kl_64x64x4_resi.yaml b/StableSR/configs/autoencoder/autoencoder_kl_64x64x4_resi.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..557f739ff5a82de37c21cbcea4fcaceee969e068
--- /dev/null
+++ b/StableSR/configs/autoencoder/autoencoder_kl_64x64x4_resi.yaml
@@ -0,0 +1,73 @@
+model:
+ base_learning_rate: 5.0e-5
+ target: ldm.models.autoencoder.AutoencoderKLResi
+ params:
+ # for training only
+ # ckpt_path: /mnt/lustre/jywang/code/stable_diffmodels/v2-1_512-ema-pruned.ckpt
+ monitor: "val/rec_loss"
+ embed_dim: 4
+ fusion_w: 1.0
+ freeze_dec: True
+ synthesis_data: False
+ lossconfig:
+ target: ldm.modules.losses.LPIPSWithDiscriminator
+ params:
+ disc_start: 501
+ kl_weight: 0
+ disc_weight: 0.025
+ disc_factor: 1.0
+
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 512
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+
+ image_key: 'gt'
+
+
+data:
+ target: main.DataModuleFromConfig
+ params:
+ batch_size: 1
+ num_workers: 6
+ wrap: True
+ train:
+ target: basicsr.data.single_image_dataset.SingleImageNPDataset
+ params:
+ gt_path: ['/mnt/lustre/share/jywang/ddpm_data/CFW_trainingdata/']
+ io_backend:
+ type: disk
+ validation:
+ target: basicsr.data.single_image_dataset.SingleImageNPDataset
+ params:
+ gt_path: ['/mnt/lustre/share/jywang/ddpm_data/CFW_trainingdata/']
+ io_backend:
+ type: disk
+
+lightning:
+ modelcheckpoint:
+ params:
+ every_n_train_steps: 1500
+ callbacks:
+ image_logger:
+ target: main.ImageLogger
+ params:
+ batch_frequency: 1500
+ max_images: 4
+ increase_log_steps: False
+
+ trainer:
+ benchmark: True
+ max_steps: 800000
+ accumulate_grad_batches: 8
diff --git a/StableSR/configs/stableSRNew/v2-finetune_text_T_512.yaml b/StableSR/configs/stableSRNew/v2-finetune_text_T_512.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..52bcb8933cac28b44a26bf062b00843e065fdad1
--- /dev/null
+++ b/StableSR/configs/stableSRNew/v2-finetune_text_T_512.yaml
@@ -0,0 +1,247 @@
+sf: 4
+model:
+ base_learning_rate: 5.0e-05
+ target: ldm.models.diffusion.ddpm.LatentDiffusionSRTextWT
+ params:
+ # parameterization: "v"
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: image
+ cond_stage_key: caption
+ image_size: 512
+ channels: 4
+ cond_stage_trainable: False # Note: different from the one we trained before
+ conditioning_key: crossattn
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+ use_ema: False
+ # for training only
+ # ckpt_path: /mnt/lustre/jywang/code/stable_diffmodels/v2-1_512-ema-pruned.ckpt
+ unfrozen_diff: False
+ random_size: False
+ time_replace: 1000
+ use_usm: True
+ # P2 weighting; not used in the final version
+ p2_gamma: ~
+ p2_k: ~
+ # ignore_keys: []
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModelDualcondV2
+ params:
+ image_size: 32 # unused
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_head_channels: 64
+ use_spatial_transformer: True
+ use_linear_in_transformer: True
+ transformer_depth: 1
+ context_dim: 1024
+ use_checkpoint: False
+ legacy: False
+ semb_channels: 256
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ # for training only
+ # ckpt_path: /mnt/lustre/jywang/code/stable_diffmodels/v2-1_512-ema-pruned.ckpt
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 512
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+ params:
+ freeze: True
+ layer: "penultimate"
+
+ structcond_stage_config:
+ target: ldm.modules.diffusionmodules.openaimodel.EncoderUNetModelWT
+ params:
+ image_size: 96
+ in_channels: 4
+ model_channels: 256
+ out_channels: 256
+ num_res_blocks: 2
+ attention_resolutions: [ 4, 2, 1 ]
+ dropout: 0
+ channel_mult: [ 1, 1, 2, 2 ]
+ conv_resample: True
+ dims: 2
+ use_checkpoint: False
+ use_fp16: False
+ num_heads: 4
+ num_head_channels: -1
+ num_heads_upsample: -1
+ use_scale_shift_norm: False
+ resblock_updown: False
+ use_new_attention_order: False
+
+
+degradation:
+ # the first degradation process
+ resize_prob: [0.2, 0.7, 0.1] # up, down, keep
+ resize_range: [0.3, 1.5]
+ gaussian_noise_prob: 0.5
+ noise_range: [1, 15]
+ poisson_scale_range: [0.05, 2.0]
+ gray_noise_prob: 0.4
+ jpeg_range: [60, 95]
+
+ # the second degradation process
+ second_blur_prob: 0.5
+ resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
+ resize_range2: [0.6, 1.2]
+ gaussian_noise_prob2: 0.5
+ noise_range2: [1, 12]
+ poisson_scale_range2: [0.05, 1.0]
+ gray_noise_prob2: 0.4
+ jpeg_range2: [60, 100]
+
+ gt_size: 512
+ no_degradation_prob: 0.01
+
+data:
+ target: main.DataModuleFromConfig
+ params:
+ batch_size: 6
+ num_workers: 6
+ wrap: false
+ train:
+ target: basicsr.data.realesrgan_dataset.RealESRGANDataset
+ params:
+ queue_size: 180
+ gt_path: ['/mnt/lustre/share/jywang/dataset/DIV8K/train_HR/', '/mnt/lustre/share/jywang/dataset/df2k_ost/GT/']
+ face_gt_path: '/mnt/lustre/share/jywang/dataset/FFHQ/1024/'
+ num_face: 10000
+ crop_size: 512
+ io_backend:
+ type: disk
+
+ blur_kernel_size: 21
+ kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+ kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+ sinc_prob: 0.1
+ blur_sigma: [0.2, 1.5]
+ betag_range: [0.5, 2.0]
+ betap_range: [1, 1.5]
+
+ blur_kernel_size2: 11
+ kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+ kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+ sinc_prob2: 0.1
+ blur_sigma2: [0.2, 1.0]
+ betag_range2: [0.5, 2.0]
+ betap_range2: [1, 1.5]
+
+ final_sinc_prob: 0.8
+
+ gt_size: 512
+ use_hflip: True
+ use_rot: False
+ validation:
+ target: basicsr.data.realesrgan_dataset.RealESRGANDataset
+ params:
+ gt_path: /mnt/lustre/share/jywang/dataset/ImageSR/DIV2K/DIV2K_train_HR/
+ crop_size: 512
+ io_backend:
+ type: disk
+
+ blur_kernel_size: 21
+ kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+ kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+ sinc_prob: 0.1
+ blur_sigma: [0.2, 1.5]
+ betag_range: [0.5, 2.0]
+ betap_range: [1, 1.5]
+
+ blur_kernel_size2: 11
+ kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+ kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+ sinc_prob2: 0.1
+ blur_sigma2: [0.2, 1.0]
+ betag_range2: [0.5, 2.0]
+ betap_range2: [1, 1.5]
+
+ final_sinc_prob: 0.8
+
+ gt_size: 512
+ use_hflip: True
+ use_rot: False
+
+test_data:
+ target: main.DataModuleFromConfig
+ params:
+ batch_size: 1
+ num_workers: 6
+ wrap: false
+ test:
+ target: basicsr.data.realesrgan_dataset.RealESRGANDataset
+ params:
+ gt_path: /mnt/lustre/share/jywang/dataset/ImageSR/DIV2K/DIV2K_train_HR/
+ crop_size: 512
+ io_backend:
+ type: disk
+
+ blur_kernel_size: 21
+ kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+ kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+ sinc_prob: 0.1
+ blur_sigma: [0.2, 1.5]
+ betag_range: [0.5, 2.0]
+ betap_range: [1, 1.5]
+
+ blur_kernel_size2: 11
+ kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+ kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+ sinc_prob2: 0.1
+ blur_sigma2: [0.2, 1.0]
+ betag_range2: [0.5, 2.0]
+ betap_range2: [1, 1.5]
+
+ final_sinc_prob: 0.8
+
+ gt_size: 512
+ use_hflip: True
+ use_rot: False
+
+lightning:
+ modelcheckpoint:
+ params:
+ every_n_train_steps: 1500
+ callbacks:
+ image_logger:
+ target: main.ImageLogger
+ params:
+ batch_frequency: 1500
+ max_images: 4
+ increase_log_steps: False
+
+ trainer:
+ benchmark: True
+ max_steps: 800000
+ accumulate_grad_batches: 4
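+
+# A fine-tuning run is typically launched by pointing main.py at this config, e.g.
+# (flag names are an assumption; see the repository README for the exact command):
+#   python main.py --train --base configs/stableSRNew/v2-finetune_text_T_512.yaml --gpus 0, --name stablesr_512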
diff --git a/StableSR/configs/stableSRNew/v2-finetune_text_T_768v.yaml b/StableSR/configs/stableSRNew/v2-finetune_text_T_768v.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ce2e65fd2d93c2d2e3c4998b925160db8e2660a8
--- /dev/null
+++ b/StableSR/configs/stableSRNew/v2-finetune_text_T_768v.yaml
@@ -0,0 +1,247 @@
+sf: 4
+model:
+ base_learning_rate: 5.0e-05
+ target: ldm.models.diffusion.ddpm.LatentDiffusionSRTextWT
+ params:
+ parameterization: "v"
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: image
+ cond_stage_key: caption
+ image_size: 768
+ channels: 4
+ cond_stage_trainable: False # Note: different from the one we trained before
+ conditioning_key: crossattn
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+ use_ema: False
+ # for training only
+ # ckpt_path: /mnt/lustre/jywang/code/stable_diffmodels/v2-1_768-ema-pruned.ckpt
+ unfrozen_diff: False
+ random_size: False
+ time_replace: 1000
+ use_usm: False
+ # P2 weighting; not used in the final version
+ p2_gamma: ~
+ p2_k: ~
+ # ignore_keys: []
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModelDualcondV2
+ params:
+ image_size: 32 # unused
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_head_channels: 64
+ use_spatial_transformer: True
+ use_linear_in_transformer: True
+ transformer_depth: 1
+ context_dim: 1024
+ use_checkpoint: False
+ legacy: False
+ semb_channels: 256
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ # for training only
+ # ckpt_path: /mnt/lustre/jywang/code/stable_diffmodels/v2-1_768-ema-pruned.ckpt
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 768
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+ params:
+ freeze: True
+ layer: "penultimate"
+
+ structcond_stage_config:
+ target: ldm.modules.diffusionmodules.openaimodel.EncoderUNetModelWT
+ params:
+ image_size: 96
+ in_channels: 4
+ model_channels: 256
+ out_channels: 256
+ num_res_blocks: 2
+ attention_resolutions: [ 4, 2, 1 ]
+ dropout: 0
+ channel_mult: [ 1, 1, 2, 2 ]
+ conv_resample: True
+ dims: 2
+ use_checkpoint: False
+ use_fp16: False
+ num_heads: 4
+ num_head_channels: -1
+ num_heads_upsample: -1
+ use_scale_shift_norm: False
+ resblock_updown: False
+ use_new_attention_order: False
+
+
+degradation:
+ # the first degradation process
+ resize_prob: [0.2, 0.7, 0.1] # up, down, keep
+ resize_range: [0.3, 1.5]
+ gaussian_noise_prob: 0.5
+ noise_range: [1, 15]
+ poisson_scale_range: [0.05, 2.0]
+ gray_noise_prob: 0.4
+ jpeg_range: [60, 95]
+
+ # the second degradation process
+ second_blur_prob: 0.5
+ resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
+ resize_range2: [0.6, 1.2]
+ gaussian_noise_prob2: 0.5
+ noise_range2: [1, 12]
+ poisson_scale_range2: [0.05, 1.0]
+ gray_noise_prob2: 0.4
+ jpeg_range2: [60, 95]
+
+ gt_size: 768
+ no_degradation_prob: 0
+
+data:
+ target: main.DataModuleFromConfig
+ params:
+ batch_size: 3
+ num_workers: 6
+ wrap: false
+ train:
+ target: basicsr.data.realesrgan_dataset.RealESRGANDataset
+ params:
+ queue_size: 180
+ gt_path: ['/mnt/lustre/share/jywang/dataset/DIV8K/train_HR/', '/mnt/lustre/share/jywang/dataset/df2k_ost/GT/']
+ face_gt_path: ['/mnt/lustre/share/jywang/dataset/FFHQ/1024/', '/mnt/lustre/share/jywang/dataset/FFHQ/ffhq_wild/']
+ num_face: 5000
+ crop_size: 768
+ io_backend:
+ type: disk
+
+ blur_kernel_size: 21
+ kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+ kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+ sinc_prob: 0.1
+ blur_sigma: [0.2, 1.5]
+ betag_range: [0.5, 2.0]
+ betap_range: [1, 1.5]
+
+ blur_kernel_size2: 11
+ kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+ kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+ sinc_prob2: 0.1
+ blur_sigma2: [0.2, 1.0]
+ betag_range2: [0.5, 2.0]
+ betap_range2: [1, 1.5]
+
+ final_sinc_prob: 0.8
+
+ gt_size: 768
+ use_hflip: True
+ use_rot: False
+ validation:
+ target: basicsr.data.realesrgan_dataset.RealESRGANDataset
+ params:
+ gt_path: /mnt/lustre/share/jywang/dataset/ImageSR/DIV2K/DIV2K_train_HR/
+ crop_size: 768
+ io_backend:
+ type: disk
+
+ blur_kernel_size: 21
+ kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+ kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+ sinc_prob: 0.1
+ blur_sigma: [0.2, 1.5]
+ betag_range: [0.5, 2.0]
+ betap_range: [1, 1.5]
+
+ blur_kernel_size2: 11
+ kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+ kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+ sinc_prob2: 0.1
+ blur_sigma2: [0.2, 1.0]
+ betag_range2: [0.5, 2.0]
+ betap_range2: [1, 1.5]
+
+ final_sinc_prob: 0.8
+
+ gt_size: 768
+ use_hflip: True
+ use_rot: False
+
+test_data:
+ target: main.DataModuleFromConfig
+ params:
+ batch_size: 1
+ num_workers: 6
+ wrap: false
+ test:
+ target: basicsr.data.realesrgan_dataset.RealESRGANDataset
+ params:
+ gt_path: ['/mnt/lustre/jywang/dataset/ImageSR/Set5/HR/', '/mnt/lustre/jywang/dataset/ImageSR/Set14/HR/']
+ crop_size: 768
+ io_backend:
+ type: disk
+
+ blur_kernel_size: 21
+ kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+ kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+ sinc_prob: 0.1
+ blur_sigma: [0.2, 1.5]
+ betag_range: [0.5, 2.0]
+ betap_range: [1, 1.5]
+
+ blur_kernel_size2: 11
+ kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+ kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+ sinc_prob2: 0.1
+ blur_sigma2: [0.2, 1.0]
+ betag_range2: [0.5, 2.0]
+ betap_range2: [1, 1.5]
+
+ final_sinc_prob: 0.8
+
+ gt_size: 768
+ use_hflip: True
+ use_rot: False
+
+lightning:
+ modelcheckpoint:
+ params:
+ every_n_train_steps: 1000
+ callbacks:
+ image_logger:
+ target: main.ImageLogger
+ params:
+ batch_frequency: 1000
+ max_images: 2
+ increase_log_steps: False
+
+ trainer:
+ benchmark: True
+ max_steps: 800000
+ accumulate_grad_batches: 4
diff --git a/StableSR/environment.yaml b/StableSR/environment.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2f45f4b8d8e8b95bc2001aab63981d8e58457e20
--- /dev/null
+++ b/StableSR/environment.yaml
@@ -0,0 +1,32 @@
+name: stablesr
+channels:
+ - pytorch
+ - defaults
+dependencies:
+ - python=3.9
+ - pip=20.3
+ - cudatoolkit=11.3
+ - pytorch=1.12.1
+ - torchvision=0.13.1
+ - numpy=1.23.1
+ - pip:
+ - albumentations==1.3.0
+ - opencv-python==4.6.0.66
+ - imageio==2.9.0
+ - imageio-ffmpeg==0.4.2
+ - pytorch-lightning==1.4.2
+ - omegaconf==2.1.1
+ - test-tube>=0.7.5
+ - streamlit==1.12.1
+ - einops==0.3.0
+ - transformers==4.19.2
+ - webdataset==0.2.5
+ - kornia==0.6
+ - open_clip_torch==2.0.2
+ - invisible-watermark>=0.1.5
+ - streamlit-drawable-canvas==0.8.0
+ - torchmetrics==0.6.0
+ - triton
+ - matplotlib
+ - wandb
+ - pillow
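+
+# To create and activate this environment with conda (standard conda workflow):
+#   conda env create -f environment.yaml
+#   conda activate stablesr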
diff --git a/StableSR/ldm/data/__init__.py b/StableSR/ldm/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/StableSR/ldm/data/base.py b/StableSR/ldm/data/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..b196c2f7aa583a3e8bc4aad9f943df0c4dae0da7
--- /dev/null
+++ b/StableSR/ldm/data/base.py
@@ -0,0 +1,23 @@
+from abc import abstractmethod
+from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
+
+
+class Txt2ImgIterableBaseDataset(IterableDataset):
+ '''
+ Define an interface to make the IterableDatasets for text2img data chainable
+ '''
+ def __init__(self, num_records=0, valid_ids=None, size=256):
+ super().__init__()
+ self.num_records = num_records
+ self.valid_ids = valid_ids
+ self.sample_ids = valid_ids
+ self.size = size
+
+ print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
+
+ def __len__(self):
+ return self.num_records
+
+ @abstractmethod
+ def __iter__(self):
+ pass
\ No newline at end of file
diff --git a/StableSR/ldm/data/imagenet.py b/StableSR/ldm/data/imagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c473f9c6965b22315dbb289eff8247c71bdc790
--- /dev/null
+++ b/StableSR/ldm/data/imagenet.py
@@ -0,0 +1,394 @@
+import os, yaml, pickle, shutil, tarfile, glob
+import cv2
+import albumentations
+import PIL
+import numpy as np
+import torchvision.transforms.functional as TF
+from omegaconf import OmegaConf
+from functools import partial
+from PIL import Image
+from tqdm import tqdm
+from torch.utils.data import Dataset, Subset
+
+import taming.data.utils as tdu
+from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
+from taming.data.imagenet import ImagePaths
+
+from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
+
+
+def synset2idx(path_to_yaml="data/index_synset.yaml"):
+ with open(path_to_yaml) as f:
+ di2s = yaml.load(f, Loader=yaml.SafeLoader)
+ return dict((v,k) for k,v in di2s.items())
+
+
+class ImageNetBase(Dataset):
+ def __init__(self, config=None):
+ self.config = config or OmegaConf.create()
+ if not type(self.config)==dict:
+ self.config = OmegaConf.to_container(self.config)
+ self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
+ self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
+ self._prepare()
+ self._prepare_synset_to_human()
+ self._prepare_idx_to_synset()
+ self._prepare_human_to_integer_label()
+ self._load()
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ return self.data[i]
+
+ def _prepare(self):
+ raise NotImplementedError()
+
+ def _filter_relpaths(self, relpaths):
+ ignore = set([
+ "n06596364_9591.JPEG",
+ ])
+ relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
+ if "sub_indices" in self.config:
+ indices = str_to_indices(self.config["sub_indices"])
+ synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
+ self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
+ files = []
+ for rpath in relpaths:
+ syn = rpath.split("/")[0]
+ if syn in synsets:
+ files.append(rpath)
+ return files
+ else:
+ return relpaths
+
+ def _prepare_synset_to_human(self):
+ SIZE = 2655750
+ URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
+ self.human_dict = os.path.join(self.root, "synset_human.txt")
+ if (not os.path.exists(self.human_dict) or
+ not os.path.getsize(self.human_dict)==SIZE):
+ download(URL, self.human_dict)
+
+ def _prepare_idx_to_synset(self):
+ URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
+ self.idx2syn = os.path.join(self.root, "index_synset.yaml")
+ if (not os.path.exists(self.idx2syn)):
+ download(URL, self.idx2syn)
+
+ def _prepare_human_to_integer_label(self):
+ URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
+ self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
+ if (not os.path.exists(self.human2integer)):
+ download(URL, self.human2integer)
+ with open(self.human2integer, "r") as f:
+ lines = f.read().splitlines()
+ assert len(lines) == 1000
+ self.human2integer_dict = dict()
+ for line in lines:
+ value, key = line.split(":")
+ self.human2integer_dict[key] = int(value)
+
+ def _load(self):
+ with open(self.txt_filelist, "r") as f:
+ self.relpaths = f.read().splitlines()
+ l1 = len(self.relpaths)
+ self.relpaths = self._filter_relpaths(self.relpaths)
+ print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
+
+ self.synsets = [p.split("/")[0] for p in self.relpaths]
+ self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
+
+ unique_synsets = np.unique(self.synsets)
+ class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
+ if not self.keep_orig_class_label:
+ self.class_labels = [class_dict[s] for s in self.synsets]
+ else:
+ self.class_labels = [self.synset2idx[s] for s in self.synsets]
+
+ with open(self.human_dict, "r") as f:
+ human_dict = f.read().splitlines()
+ human_dict = dict(line.split(maxsplit=1) for line in human_dict)
+
+ self.human_labels = [human_dict[s] for s in self.synsets]
+
+ labels = {
+ "relpath": np.array(self.relpaths),
+ "synsets": np.array(self.synsets),
+ "class_label": np.array(self.class_labels),
+ "human_label": np.array(self.human_labels),
+ }
+
+ if self.process_images:
+ self.size = retrieve(self.config, "size", default=256)
+ self.data = ImagePaths(self.abspaths,
+ labels=labels,
+ size=self.size,
+ random_crop=self.random_crop,
+ )
+ else:
+ self.data = self.abspaths
+
+
+class ImageNetTrain(ImageNetBase):
+ NAME = "ILSVRC2012_train"
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
+ AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
+ FILES = [
+ "ILSVRC2012_img_train.tar",
+ ]
+ SIZES = [
+ 147897477120,
+ ]
+
+ def __init__(self, process_images=True, data_root=None, **kwargs):
+ self.process_images = process_images
+ self.data_root = data_root
+ super().__init__(**kwargs)
+
+ def _prepare(self):
+ if self.data_root:
+ self.root = os.path.join(self.data_root, self.NAME)
+ else:
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
+
+ self.datadir = os.path.join(self.root, "data")
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
+ self.expected_length = 1281167
+ self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
+ default=True)
+ if not tdu.is_prepared(self.root):
+ # prep
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
+
+ datadir = self.datadir
+ if not os.path.exists(datadir):
+ path = os.path.join(self.root, self.FILES[0])
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
+ import academictorrents as at
+ atpath = at.get(self.AT_HASH, datastore=self.root)
+ assert atpath == path
+
+ print("Extracting {} to {}".format(path, datadir))
+ os.makedirs(datadir, exist_ok=True)
+ with tarfile.open(path, "r:") as tar:
+ tar.extractall(path=datadir)
+
+ print("Extracting sub-tars.")
+ subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
+ for subpath in tqdm(subpaths):
+ subdir = subpath[:-len(".tar")]
+ os.makedirs(subdir, exist_ok=True)
+ with tarfile.open(subpath, "r:") as tar:
+ tar.extractall(path=subdir)
+
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
+ filelist = sorted(filelist)
+ filelist = "\n".join(filelist)+"\n"
+ with open(self.txt_filelist, "w") as f:
+ f.write(filelist)
+
+ tdu.mark_prepared(self.root)
+
+
+class ImageNetValidation(ImageNetBase):
+ NAME = "ILSVRC2012_validation"
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
+ AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
+ VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
+ FILES = [
+ "ILSVRC2012_img_val.tar",
+ "validation_synset.txt",
+ ]
+ SIZES = [
+ 6744924160,
+ 1950000,
+ ]
+
+ def __init__(self, process_images=True, data_root=None, **kwargs):
+ self.data_root = data_root
+ self.process_images = process_images
+ super().__init__(**kwargs)
+
+ def _prepare(self):
+ if self.data_root:
+ self.root = os.path.join(self.data_root, self.NAME)
+ else:
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
+ self.datadir = os.path.join(self.root, "data")
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
+ self.expected_length = 50000
+ self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
+ default=False)
+ if not tdu.is_prepared(self.root):
+ # prep
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
+
+ datadir = self.datadir
+ if not os.path.exists(datadir):
+ path = os.path.join(self.root, self.FILES[0])
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
+ import academictorrents as at
+ atpath = at.get(self.AT_HASH, datastore=self.root)
+ assert atpath == path
+
+ print("Extracting {} to {}".format(path, datadir))
+ os.makedirs(datadir, exist_ok=True)
+ with tarfile.open(path, "r:") as tar:
+ tar.extractall(path=datadir)
+
+ vspath = os.path.join(self.root, self.FILES[1])
+ if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
+ download(self.VS_URL, vspath)
+
+ with open(vspath, "r") as f:
+ synset_dict = f.read().splitlines()
+ synset_dict = dict(line.split() for line in synset_dict)
+
+ print("Reorganizing into synset folders")
+ synsets = np.unique(list(synset_dict.values()))
+ for s in synsets:
+ os.makedirs(os.path.join(datadir, s), exist_ok=True)
+ for k, v in synset_dict.items():
+ src = os.path.join(datadir, k)
+ dst = os.path.join(datadir, v)
+ shutil.move(src, dst)
+
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
+ filelist = sorted(filelist)
+ filelist = "\n".join(filelist)+"\n"
+ with open(self.txt_filelist, "w") as f:
+ f.write(filelist)
+
+ tdu.mark_prepared(self.root)
+
+
+
+class ImageNetSR(Dataset):
+ def __init__(self, size=None,
+ degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
+ random_crop=True):
+ """
+ Imagenet Superresolution Dataloader
+ Performs following ops in order:
+ 1. crops a crop of size s from image either as random or center crop
+ 2. resizes crop to size with cv2.area_interpolation
+ 3. degrades resized crop with degradation_fn
+
+ :param size: resizing to size after cropping
+ :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
+ :param downscale_f: Low Resolution Downsample factor
+ :param min_crop_f: determines crop size s,
+ where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
+ :param max_crop_f: ""
+ :param data_root:
+ :param random_crop:
+ """
+ self.base = self.get_base()
+ assert size
+ assert (size / downscale_f).is_integer()
+ self.size = size
+ self.LR_size = int(size / downscale_f)
+ self.min_crop_f = min_crop_f
+ self.max_crop_f = max_crop_f
+ assert(max_crop_f <= 1.)
+ self.center_crop = not random_crop
+
+ self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
+
+ self.pil_interpolation = False # gets reset later in case interp_op is from pillow
+
+ if degradation == "bsrgan":
+ self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
+
+ elif degradation == "bsrgan_light":
+ self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
+
+ else:
+ interpolation_fn = {
+ "cv_nearest": cv2.INTER_NEAREST,
+ "cv_bilinear": cv2.INTER_LINEAR,
+ "cv_bicubic": cv2.INTER_CUBIC,
+ "cv_area": cv2.INTER_AREA,
+ "cv_lanczos": cv2.INTER_LANCZOS4,
+ "pil_nearest": PIL.Image.NEAREST,
+ "pil_bilinear": PIL.Image.BILINEAR,
+ "pil_bicubic": PIL.Image.BICUBIC,
+ "pil_box": PIL.Image.BOX,
+ "pil_hamming": PIL.Image.HAMMING,
+ "pil_lanczos": PIL.Image.LANCZOS,
+ }[degradation]
+
+ self.pil_interpolation = degradation.startswith("pil_")
+
+ if self.pil_interpolation:
+ self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
+
+ else:
+ self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
+ interpolation=interpolation_fn)
+
+ def __len__(self):
+ return len(self.base)
+
+ def __getitem__(self, i):
+ example = self.base[i]
+ image = Image.open(example["file_path_"])
+
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ image = np.array(image).astype(np.uint8)
+
+ min_side_len = min(image.shape[:2])
+ crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
+ crop_side_len = int(crop_side_len)
+
+ if self.center_crop:
+ self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
+
+ else:
+ self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
+
+ image = self.cropper(image=image)["image"]
+ image = self.image_rescaler(image=image)["image"]
+
+ if self.pil_interpolation:
+ image_pil = PIL.Image.fromarray(image)
+ LR_image = self.degradation_process(image_pil)
+ LR_image = np.array(LR_image).astype(np.uint8)
+
+ else:
+ LR_image = self.degradation_process(image=image)["image"]
+
+ example["image"] = (image/127.5 - 1.0).astype(np.float32)
+ example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
+
+ return example
+
+
+class ImageNetSRTrain(ImageNetSR):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def get_base(self):
+ with open("data/imagenet_train_hr_indices.p", "rb") as f:
+ indices = pickle.load(f)
+ dset = ImageNetTrain(process_images=False,)
+ return Subset(dset, indices)
+
+
+class ImageNetSRValidation(ImageNetSR):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def get_base(self):
+ with open("data/imagenet_val_hr_indices.p", "rb") as f:
+ indices = pickle.load(f)
+ dset = ImageNetValidation(process_images=False,)
+ return Subset(dset, indices)
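+
+
+# Usage sketch (an assumption about typical instantiation; it requires a prepared
+# ImageNet copy plus data/imagenet_val_hr_indices.p, neither of which ships here):
+#   ds = ImageNetSRValidation(size=256, degradation='bsrgan_light', downscale_f=4)
+#   example = ds[0]  # dict with 'image' (HR) and 'LR_image' (degraded), both in [-1, 1]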
diff --git a/StableSR/ldm/data/lsun.py b/StableSR/ldm/data/lsun.py
new file mode 100644
index 0000000000000000000000000000000000000000..6256e45715ff0b57c53f985594d27cbbbff0e68e
--- /dev/null
+++ b/StableSR/ldm/data/lsun.py
@@ -0,0 +1,92 @@
+import os
+import numpy as np
+import PIL
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+
+class LSUNBase(Dataset):
+ def __init__(self,
+ txt_file,
+ data_root,
+ size=None,
+ interpolation="bicubic",
+ flip_p=0.5
+ ):
+ self.data_paths = txt_file
+ self.data_root = data_root
+ with open(self.data_paths, "r") as f:
+ self.image_paths = f.read().splitlines()
+ self._length = len(self.image_paths)
+ self.labels = {
+ "relative_file_path_": [l for l in self.image_paths],
+ "file_path_": [os.path.join(self.data_root, l)
+ for l in self.image_paths],
+ }
+
+ self.size = size
+ self.interpolation = {"linear": PIL.Image.LINEAR,
+ "bilinear": PIL.Image.BILINEAR,
+ "bicubic": PIL.Image.BICUBIC,
+ "lanczos": PIL.Image.LANCZOS,
+ }[interpolation]
+ self.flip = transforms.RandomHorizontalFlip(p=flip_p)
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = dict((k, self.labels[k][i]) for k in self.labels)
+ image = Image.open(example["file_path_"])
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+
+ # default to score-sde preprocessing
+ img = np.array(image).astype(np.uint8)
+ crop = min(img.shape[0], img.shape[1])
+ h, w, = img.shape[0], img.shape[1]
+ img = img[(h - crop) // 2:(h + crop) // 2,
+ (w - crop) // 2:(w + crop) // 2]
+
+ image = Image.fromarray(img)
+ if self.size is not None:
+ image = image.resize((self.size, self.size), resample=self.interpolation)
+
+ image = self.flip(image)
+ image = np.array(image).astype(np.uint8)
+ example["image"] = (image / 127.5 - 1.0).astype(np.float32)
+ return example
+
+
+class LSUNChurchesTrain(LSUNBase):
+ def __init__(self, **kwargs):
+ super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
+
+
+class LSUNChurchesValidation(LSUNBase):
+ def __init__(self, flip_p=0., **kwargs):
+ super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
+ flip_p=flip_p, **kwargs)
+
+
+class LSUNBedroomsTrain(LSUNBase):
+ def __init__(self, **kwargs):
+ super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
+
+
+class LSUNBedroomsValidation(LSUNBase):
+ def __init__(self, flip_p=0.0, **kwargs):
+ super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
+ flip_p=flip_p, **kwargs)
+
+
+class LSUNCatsTrain(LSUNBase):
+ def __init__(self, **kwargs):
+ super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
+
+
+class LSUNCatsValidation(LSUNBase):
+ def __init__(self, flip_p=0., **kwargs):
+ super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
+ flip_p=flip_p, **kwargs)
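+
+
+# Usage sketch (illustrative; assumes the filelists under data/lsun/ have been prepared):
+#   ds = LSUNChurchesValidation(size=256)
+#   example = ds[0]  # {'relative_file_path_', 'file_path_', 'image'}, image in [-1, 1]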
diff --git a/StableSR/ldm/lr_scheduler.py b/StableSR/ldm/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..be39da9ca6dacc22bf3df9c7389bbb403a4a3ade
--- /dev/null
+++ b/StableSR/ldm/lr_scheduler.py
@@ -0,0 +1,98 @@
+import numpy as np
+
+
+class LambdaWarmUpCosineScheduler:
+ """
+ note: use with a base_lr of 1.0
+ """
+ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
+ self.lr_warm_up_steps = warm_up_steps
+ self.lr_start = lr_start
+ self.lr_min = lr_min
+ self.lr_max = lr_max
+ self.lr_max_decay_steps = max_decay_steps
+ self.last_lr = 0.
+ self.verbosity_interval = verbosity_interval
+
+ def schedule(self, n, **kwargs):
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+ if n < self.lr_warm_up_steps:
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
+ self.last_lr = lr
+ return lr
+ else:
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
+ t = min(t, 1.0)
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
+ 1 + np.cos(t * np.pi))
+ self.last_lr = lr
+ return lr
+
+ def __call__(self, n, **kwargs):
+ return self.schedule(n,**kwargs)
+
+
+class LambdaWarmUpCosineScheduler2:
+ """
+ supports repeated iterations, configurable via lists
+ note: use with a base_lr of 1.0.
+ """
+ def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
+ assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
+ self.lr_warm_up_steps = warm_up_steps
+ self.f_start = f_start
+ self.f_min = f_min
+ self.f_max = f_max
+ self.cycle_lengths = cycle_lengths
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
+ self.last_f = 0.
+ self.verbosity_interval = verbosity_interval
+
+ def find_in_interval(self, n):
+ interval = 0
+ for cl in self.cum_cycles[1:]:
+ if n <= cl:
+ return interval
+ interval += 1
+
+ def schedule(self, n, **kwargs):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+ f"current cycle {cycle}")
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
+ t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
+ t = min(t, 1.0)
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
+ 1 + np.cos(t * np.pi))
+ self.last_f = f
+ return f
+
+ def __call__(self, n, **kwargs):
+ return self.schedule(n, **kwargs)
+
+
+class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
+
+ def schedule(self, n, **kwargs):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+ f"current cycle {cycle}")
+
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
+ self.last_f = f
+ return f
+
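+# Usage sketch (an assumption, mirroring how these lambda schedulers are typically wired
+# into torch.optim.lr_scheduler.LambdaLR with the base lr acting as a multiplier of 1.0):
+#   sched = LambdaLinearScheduler(warm_up_steps=[100], f_min=[1.0], f_max=[1.0],
+#                                 f_start=[1.e-6], cycle_lengths=[10000000000000])
+#   lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=sched.schedule)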
diff --git a/StableSR/ldm/models/autoencoder.py b/StableSR/ldm/models/autoencoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b4156448b61788681c7bcdcdc9123a89a732ec8
--- /dev/null
+++ b/StableSR/ldm/models/autoencoder.py
@@ -0,0 +1,919 @@
+import torch
+import numpy as np  # used by VQModel.get_input for per-batch resizing
+import pytorch_lightning as pl
+import torch.nn.functional as F
+from contextlib import contextmanager
+from packaging import version  # used in VQModel._validation_step
+from torch.optim.lr_scheduler import LambdaLR  # used in VQModel.configure_optimizers
+
+from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
+
+from ldm.modules.diffusionmodules.model import Encoder, Decoder, Decoder_Mix
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+from ldm.modules.ema import LitEma  # used when use_ema=True
+
+from ldm.util import instantiate_from_config
+
+from basicsr.utils import DiffJPEG, USMSharp
+from basicsr.utils.img_process_util import filter2D
+from basicsr.data.transforms import paired_random_crop, triplet_random_crop
+from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt, random_add_speckle_noise_pt, random_add_saltpepper_noise_pt
+import random
+
+import torchvision.transforms as transforms
+
+
+class VQModel(pl.LightningModule):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ batch_resize_range=None,
+ scheduler_config=None,
+ lr_g_factor=1.0,
+ remap=None,
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
+ use_ema=False
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.n_embed = n_embed
+ self.image_key = image_key
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ self.loss = instantiate_from_config(lossconfig)
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
+ remap=remap,
+ sane_index_shape=sane_index_shape)
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels)==int
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+ self.batch_resize_range = batch_resize_range
+ if self.batch_resize_range is not None:
+ print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
+
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+ self.scheduler_config = scheduler_config
+ self.lr_g_factor = lr_g_factor
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.parameters())
+ self.model_ema.copy_to(self)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ missing, unexpected = self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ print(f"Unexpected Keys: {unexpected}")
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self)
+
+ def encode(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ quant, emb_loss, info = self.quantize(h)
+ return quant, emb_loss, info
+
+ def encode_to_prequant(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ return h
+
+ def decode(self, quant):
+ quant = self.post_quant_conv(quant)
+ dec = self.decoder(quant)
+ return dec
+
+ def decode_code(self, code_b):
+ quant_b = self.quantize.embed_code(code_b)
+ dec = self.decode(quant_b)
+ return dec
+
+ def forward(self, input, return_pred_indices=False):
+ quant, diff, (_,_,ind) = self.encode(input)
+ dec = self.decode(quant)
+ if return_pred_indices:
+ return dec, diff, ind
+ return dec, diff
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+ if self.batch_resize_range is not None:
+ lower_size = self.batch_resize_range[0]
+ upper_size = self.batch_resize_range[1]
+ if self.global_step <= 4:
+ # do the first few batches with max size to avoid later oom
+ new_resize = upper_size
+ else:
+ new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
+ if new_resize != x.shape[2]:
+ x = F.interpolate(x, size=new_resize, mode="bicubic")
+ x = x.detach()
+ return x
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ # https://github.com/pytorch/pytorch/issues/37142
+ # try not to fool the heuristics
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss, ind = self(x, return_pred_indices=True)
+
+ if optimizer_idx == 0:
+ # autoencode
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train",
+ predicted_indices=ind)
+
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return aeloss
+
+ if optimizer_idx == 1:
+ # discriminator
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return discloss
+
+ def validation_step(self, batch, batch_idx):
+ log_dict = self._validation_step(batch, batch_idx)
+ with self.ema_scope():
+ log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
+ return log_dict
+
+ def _validation_step(self, batch, batch_idx, suffix=""):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss, ind = self(x, return_pred_indices=True)
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="val"+suffix,
+ predicted_indices=ind
+ )
+
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="val"+suffix,
+ predicted_indices=ind
+ )
+ rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
+ self.log(f"val{suffix}/rec_loss", rec_loss,
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+ self.log(f"val{suffix}/aeloss", aeloss,
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+ if version.parse(pl.__version__) >= version.parse('1.4.0'):
+ del log_dict_ae[f"val{suffix}/rec_loss"]
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def configure_optimizers(self):
+ lr_d = self.learning_rate
+ lr_g = self.lr_g_factor*self.learning_rate
+ print("lr_d", lr_d)
+ print("lr_g", lr_g)
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quantize.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=lr_g, betas=(0.5, 0.9))
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ lr=lr_d, betas=(0.5, 0.9))
+
+ if self.scheduler_config is not None:
+ scheduler = instantiate_from_config(self.scheduler_config)
+
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
+ 'interval': 'step',
+ 'frequency': 1
+ },
+ {
+ 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
+ 'interval': 'step',
+ 'frequency': 1
+ },
+ ]
+ return [opt_ae, opt_disc], scheduler
+ return [opt_ae, opt_disc], []
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+
+ def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ if only_inputs:
+ log["inputs"] = x
+ return log
+ xrec, _ = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log["inputs"] = x
+ log["reconstructions"] = xrec
+ if plot_ema:
+ with self.ema_scope():
+ xrec_ema, _ = self(x)
+ if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
+ log["reconstructions_ema"] = xrec_ema
+ return log
+
+ def to_rgb(self, x):
+ assert self.image_key == "segmentation"
+ if not hasattr(self, "colorize"):
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+ x = F.conv2d(x, weight=self.colorize)
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+ return x
+
+class VQModelInterface(VQModel):
+ def __init__(self, embed_dim, *args, **kwargs):
+ super().__init__(embed_dim=embed_dim, *args, **kwargs)
+ self.embed_dim = embed_dim
+
+ def encode(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ return h
+
+ def decode(self, h, force_not_quantize=False):
+ # also go through quantization layer
+ if not force_not_quantize:
+ quant, emb_loss, info = self.quantize(h)
+ else:
+ quant = h
+ quant = self.post_quant_conv(quant)
+ dec = self.decoder(quant)
+ return dec
+
+class AutoencoderKL(pl.LightningModule):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ ):
+ super().__init__()
+ self.image_key = image_key
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ self.loss = instantiate_from_config(lossconfig)
+ assert ddconfig["double_z"]
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ self.embed_dim = embed_dim
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels)==int
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ if 'first_stage_model' in k:
+ sd[k[18:]] = sd[k]
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+ sd, strict=False)
+ print(f"Encoder Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ # if len(unexpected) > 0:
+ # print(f"Unexpected Keys: {unexpected}")
+
+ def encode(self, x, return_encfea=False):
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ if return_encfea:
+ return posterior, moments
+ return posterior
+
+ def encode_gt(self, x, new_encoder):
+ h = new_encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior, moments
+
+ def decode(self, z):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ # x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+ x = x.to(memory_format=torch.contiguous_format).float()
+ # x = x*2.0-1.0
+ return x
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ inputs = self.get_input(batch, self.image_key)
+ reconstructions, posterior = self(inputs)
+
+ if optimizer_idx == 0:
+ # train encoder+decoder+logvar
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return aeloss
+
+ if optimizer_idx == 1:
+ # train the discriminator
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return discloss
+
+ def validation_step(self, batch, batch_idx):
+ inputs = self.get_input(batch, self.image_key)
+ reconstructions, posterior = self(inputs)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=lr, betas=(0.5, 0.9))
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ lr=lr, betas=(0.5, 0.9))
+ return [opt_ae, opt_disc], []
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+
+ @torch.no_grad()
+ def log_images(self, batch, only_inputs=False, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ if not only_inputs:
+ xrec, posterior = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ # log["samples"] = self.decode(torch.randn_like(posterior.sample()))
+ log["reconstructions"] = xrec
+ log["inputs"] = x
+ return log
+
+ def to_rgb(self, x):
+ assert self.image_key == "segmentation"
+ if not hasattr(self, "colorize"):
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+ x = F.conv2d(x, weight=self.colorize)
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+ return x
+
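+# First-stage round trip as used by the latent diffusion model (an illustrative sketch;
+# the 0.18215 factor matches scale_factor in the configs above):
+#   posterior = autoencoder.encode(img)     # img in [-1, 1], shape (B, 3, H, W)
+#   z = posterior.sample() * 0.18215        # latent fed to the diffusion UNet
+#   rec = autoencoder.decode(z / 0.18215)   # back to image space
+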
+class IdentityFirstStage(torch.nn.Module):
+ def __init__(self, *args, vq_interface=False, **kwargs):
+ self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
+ super().__init__()
+
+ def encode(self, x, *args, **kwargs):
+ return x
+
+ def decode(self, x, *args, **kwargs):
+ return x
+
+ def quantize(self, x, *args, **kwargs):
+ if self.vq_interface:
+ return x, None, [None, None, None]
+ return x
+
+ def forward(self, x, *args, **kwargs):
+ return x
+
+class AutoencoderKLResi(pl.LightningModule):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ fusion_w=1.0,
+ freeze_dec=True,
+ synthesis_data=False,
+ use_usm=False,
+ test_gt=False,
+ ):
+ super().__init__()
+ self.image_key = image_key
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder_Mix(**ddconfig)
+ self.decoder.fusion_w = fusion_w
+ self.loss = instantiate_from_config(lossconfig)
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ self.embed_dim = embed_dim
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels)==int
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+ if ckpt_path is not None:
+ missing_list = self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+ else:
+ missing_list = []
+
+ print('>>>>>>>>>>>>>>>>>missing>>>>>>>>>>>>>>>>>>>')
+ print(missing_list)
+ self.synthesis_data = synthesis_data
+ self.use_usm = use_usm
+ self.test_gt = test_gt
+
+ if freeze_dec:
+ for name, param in self.named_parameters():
+ if 'fusion_layer' in name:
+ param.requires_grad = True
+ # elif 'encoder' in name:
+ # param.requires_grad = True
+ # elif 'quant_conv' in name and 'post_quant_conv' not in name:
+ # param.requires_grad = True
+ elif 'loss.discriminator' in name:
+ param.requires_grad = True
+ else:
+ param.requires_grad = False
+
+ print('>>>>>>>>>>>>>>>>>trainable_list>>>>>>>>>>>>>>>>>>>')
+ trainable_list = []
+ for name, params in self.named_parameters():
+ if params.requires_grad:
+ trainable_list.append(name)
+ print(trainable_list)
+
+ print('>>>>>>>>>>>>>>>>>Untrainable_list>>>>>>>>>>>>>>>>>>>')
+ untrainable_list = []
+ for name, params in self.named_parameters():
+ if not params.requires_grad:
+ untrainable_list.append(name)
+ print(untrainable_list)
+ # untrainable_list = list(set(trainable_list).difference(set(missing_list)))
+ # print('>>>>>>>>>>>>>>>>>untrainable_list>>>>>>>>>>>>>>>>>>>')
+ # print(untrainable_list)
+
+ # def init_from_ckpt(self, path, ignore_keys=list()):
+ # sd = torch.load(path, map_location="cpu")["state_dict"]
+ # keys = list(sd.keys())
+ # for k in keys:
+ # for ik in ignore_keys:
+ # if k.startswith(ik):
+ # print("Deleting key {} from state_dict.".format(k))
+ # del sd[k]
+ # self.load_state_dict(sd, strict=False)
+ # print(f"Restored from {path}")
+
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ if 'first_stage_model' in k:
+ sd[k[18:]] = sd[k]
+ del sd[k]
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+ sd, strict=False)
+ print(f"Encoder Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+ return missing
+
+ def encode(self, x):
+ h, enc_fea = self.encoder(x, return_fea=True)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ # posterior = h
+ return posterior, enc_fea
+
+ def encode_gt(self, x, new_encoder):
+ h = new_encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior, moments
+
+ def decode(self, z, enc_fea):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z, enc_fea)
+ return dec
+
+ def forward(self, input, latent, sample_posterior=True):
+ posterior, enc_fea_lq = self.encode(input)
+ dec = self.decode(latent, enc_fea_lq)
+ return dec, posterior
+
+ @torch.no_grad()
+ def _dequeue_and_enqueue(self):
+ """It is the training pair pool for increasing the diversity in a batch.
+
+ Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
+ batch could not have different resize scaling factors. Therefore, we employ this training pair pool
+ to increase the degradation diversity in a batch.
+ """
+ # initialize
+ b, c, h, w = self.lq.size()
+ _, c_, h_, w_ = self.latent.size()
+ if b == self.configs.data.params.batch_size:
+ if not hasattr(self, 'queue_size'):
+ self.queue_size = self.configs.data.params.train.params.get('queue_size', b*50)
+ if not hasattr(self, 'queue_lr'):
+ assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
+ self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
+ _, c, h, w = self.gt.size()
+ self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
+ self.queue_sample = torch.zeros(self.queue_size, c, h, w).cuda()
+ self.queue_latent = torch.zeros(self.queue_size, c_, h_, w_).cuda()
+ self.queue_ptr = 0
+ if self.queue_ptr == self.queue_size: # the pool is full
+ # do dequeue and enqueue
+ # shuffle
+ idx = torch.randperm(self.queue_size)
+ self.queue_lr = self.queue_lr[idx]
+ self.queue_gt = self.queue_gt[idx]
+ self.queue_sample = self.queue_sample[idx]
+ self.queue_latent = self.queue_latent[idx]
+ # get first b samples
+ lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
+ gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
+ sample_dequeue = self.queue_sample[0:b, :, :, :].clone()
+ latent_dequeue = self.queue_latent[0:b, :, :, :].clone()
+ # update the queue
+ self.queue_lr[0:b, :, :, :] = self.lq.clone()
+ self.queue_gt[0:b, :, :, :] = self.gt.clone()
+ self.queue_sample[0:b, :, :, :] = self.sample.clone()
+ self.queue_latent[0:b, :, :, :] = self.latent.clone()
+
+ self.lq = lq_dequeue
+ self.gt = gt_dequeue
+ self.sample = sample_dequeue
+ self.latent = latent_dequeue
+ else:
+ # only do enqueue
+ self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
+ self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
+ self.queue_sample[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.sample.clone()
+ self.queue_latent[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.latent.clone()
+ self.queue_ptr = self.queue_ptr + b
+
+ def get_input(self, batch):
+ input = batch['lq']
+ gt = batch['gt']
+ latent = batch['latent']
+ sample = batch['sample']
+
+ assert not torch.isnan(latent).any()
+
+ input = input.to(memory_format=torch.contiguous_format).float()
+ gt = gt.to(memory_format=torch.contiguous_format).float()
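+ # 0.18215 is the Stable Diffusion latent scale factor; dividing undoes the scaling applied when the latent was stored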
+ latent = latent.to(memory_format=torch.contiguous_format).float() / 0.18215
+
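+ # map images from [0, 1] to [-1, 1]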
+ gt = gt * 2.0 - 1.0
+ input = input * 2.0 - 1.0
+ sample = sample * 2.0 - 1.0
+
+ return input, gt, latent, sample
+
+ @torch.no_grad()
+ def get_input_synthesis(self, batch, val=False, test_gt=False):
+
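+ # Synthesize the LQ input on the fly with a Real-ESRGAN-style two-stage degradation pipeline:
+ # blur -> random resize -> noise -> JPEG, applied twice, followed by a final sinc filter and resize back.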
+ jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
+ im_gt = batch['gt'].cuda()
+ if self.use_usm:
+ usm_sharpener = USMSharp().cuda() # do usm sharpening
+ im_gt = usm_sharpener(im_gt)
+ im_gt = im_gt.to(memory_format=torch.contiguous_format).float()
+ kernel1 = batch['kernel1'].cuda()
+ kernel2 = batch['kernel2'].cuda()
+ sinc_kernel = batch['sinc_kernel'].cuda()
+
+ ori_h, ori_w = im_gt.size()[2:4]
+
+ # ----------------------- The first degradation process ----------------------- #
+ # blur
+ out = filter2D(im_gt, kernel1)
+ # random resize
+ updown_type = random.choices(
+ ['up', 'down', 'keep'],
+ self.configs.degradation['resize_prob'],
+ )[0]
+ if updown_type == 'up':
+ scale = random.uniform(1, self.configs.degradation['resize_range'][1])
+ elif updown_type == 'down':
+ scale = random.uniform(self.configs.degradation['resize_range'][0], 1)
+ else:
+ scale = 1
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(out, scale_factor=scale, mode=mode)
+ # add noise
+ gray_noise_prob = self.configs.degradation['gray_noise_prob']
+ if random.random() < self.configs.degradation['gaussian_noise_prob']:
+ out = random_add_gaussian_noise_pt(
+ out,
+ sigma_range=self.configs.degradation['noise_range'],
+ clip=True,
+ rounds=False,
+ gray_prob=gray_noise_prob,
+ )
+ else:
+ out = random_add_poisson_noise_pt(
+ out,
+ scale_range=self.configs.degradation['poisson_scale_range'],
+ gray_prob=gray_noise_prob,
+ clip=True,
+ rounds=False)
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.configs.degradation['jpeg_range'])
+ out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
+ out = jpeger(out, quality=jpeg_p)
+
+ # ----------------------- The second degradation process ----------------------- #
+ # blur
+ if random.random() < self.configs.degradation['second_blur_prob']:
+ out = filter2D(out, kernel2)
+ # random resize
+ updown_type = random.choices(
+ ['up', 'down', 'keep'],
+ self.configs.degradation['resize_prob2'],
+ )[0]
+ if updown_type == 'up':
+ scale = random.uniform(1, self.configs.degradation['resize_range2'][1])
+ elif updown_type == 'down':
+ scale = random.uniform(self.configs.degradation['resize_range2'][0], 1)
+ else:
+ scale = 1
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(
+ out,
+ size=(int(ori_h / self.configs.sf * scale),
+ int(ori_w / self.configs.sf * scale)),
+ mode=mode,
+ )
+ # add noise
+ gray_noise_prob = self.configs.degradation['gray_noise_prob2']
+ if random.random() < self.configs.degradation['gaussian_noise_prob2']:
+ out = random_add_gaussian_noise_pt(
+ out,
+ sigma_range=self.configs.degradation['noise_range2'],
+ clip=True,
+ rounds=False,
+ gray_prob=gray_noise_prob,
+ )
+ else:
+ out = random_add_poisson_noise_pt(
+ out,
+ scale_range=self.configs.degradation['poisson_scale_range2'],
+ gray_prob=gray_noise_prob,
+ clip=True,
+ rounds=False,
+ )
+
+ # JPEG compression + the final sinc filter
+ # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
+ # as one operation.
+ # We consider two orders:
+ # 1. [resize back + sinc filter] + JPEG compression
+ # 2. JPEG compression + [resize back + sinc filter]
+ # Empirically, we find that other combinations (sinc + JPEG + resize) introduce twisted lines.
+ if random.random() < 0.5:
+ # resize back + the final sinc filter
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(
+ out,
+ size=(ori_h // self.configs.sf,
+ ori_w // self.configs.sf),
+ mode=mode,
+ )
+ out = filter2D(out, sinc_kernel)
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.configs.degradation['jpeg_range2'])
+ out = torch.clamp(out, 0, 1)
+ out = jpeger(out, quality=jpeg_p)
+ else:
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.configs.degradation['jpeg_range2'])
+ out = torch.clamp(out, 0, 1)
+ out = jpeger(out, quality=jpeg_p)
+ # resize back + the final sinc filter
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(
+ out,
+ size=(ori_h // self.configs.sf,
+ ori_w // self.configs.sf),
+ mode=mode,
+ )
+ out = filter2D(out, sinc_kernel)
+
+ # clamp and round
+ im_lq = torch.clamp(out, 0, 1.0)
+
+ # random crop
+ gt_size = self.configs.degradation['gt_size']
+ im_gt, im_lq = paired_random_crop(im_gt, im_lq, gt_size, self.configs.sf)
+ self.lq, self.gt = im_lq, im_gt
+
+ self.lq = F.interpolate(
+ self.lq,
+ size=(self.gt.size(-2),
+ self.gt.size(-1)),
+ mode='bicubic',
+ )
+
+ self.latent = batch['latent'] / 0.18215
+ self.sample = batch['sample'] * 2 - 1.0
+ # training pair pool
+ if not val:
+ self._dequeue_and_enqueue()
+ # note: self.gt and self.lq may have been swapped with pooled samples by self._dequeue_and_enqueue above
+ self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
+ self.lq = self.lq*2 - 1.0
+ self.gt = self.gt*2 - 1.0
+
+ self.lq = torch.clamp(self.lq, -1.0, 1.0)
+
+ x = self.lq
+ y = self.gt
+ x = x.to(self.device)
+ y = y.to(self.device)
+
+ if self.test_gt:
+ return y, y, self.latent.to(self.device), self.sample.to(self.device)
+ else:
+ return x, y, self.latent.to(self.device), self.sample.to(self.device)
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ if self.synthesis_data:
+ inputs, gts, latents, _ = self.get_input_synthesis(batch, val=False)
+ else:
+ inputs, gts, latents, _ = self.get_input(batch)
+ reconstructions, posterior = self(inputs, latents)
+
+ if optimizer_idx == 0:
+ # train encoder+decoder+logvar
+ aeloss, log_dict_ae = self.loss(gts, reconstructions, posterior, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return aeloss
+
+ if optimizer_idx == 1:
+ # train the discriminator
+ discloss, log_dict_disc = self.loss(gts, reconstructions, posterior, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return discloss
+
+ def validation_step(self, batch, batch_idx):
+ inputs, gts, latents, _ = self.get_input(batch)
+
+ reconstructions, posterior = self(inputs, latents)
+ aeloss, log_dict_ae = self.loss(gts, reconstructions, posterior, 0, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+
+ discloss, log_dict_disc = self.loss(gts, reconstructions, posterior, 1, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ # list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=lr, betas=(0.5, 0.9))
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ lr=lr, betas=(0.5, 0.9))
+ return [opt_ae, opt_disc], []
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+
+ @torch.no_grad()
+ def log_images(self, batch, only_inputs=False, **kwargs):
+ log = dict()
+ if self.synthesis_data:
+ x, gts, latents, samples = self.get_input_synthesis(batch, val=False)
+ else:
+ x, gts, latents, samples = self.get_input(batch)
+ x = x.to(self.device)
+ latents = latents.to(self.device)
+ samples = samples.to(self.device)
+ if not only_inputs:
+ xrec, posterior = self(x, latents)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ x = self.to_rgb(x)
+ gts = self.to_rgb(gts)
+ samples = self.to_rgb(samples)
+ xrec = self.to_rgb(xrec)
+ # log["samples"] = self.decode(torch.randn_like(posterior.sample()))
+ log["reconstructions"] = xrec
+ log["inputs"] = x
+ log["gts"] = gts
+ log["samples"] = samples
+ return log
+
+ def to_rgb(self, x):
+ assert self.image_key == "segmentation"
+ if not hasattr(self, "colorize"):
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+ x = F.conv2d(x, weight=self.colorize)
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+ return x
diff --git a/StableSR/ldm/models/diffusion/__init__.py b/StableSR/ldm/models/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/StableSR/ldm/models/diffusion/classifier.py b/StableSR/ldm/models/diffusion/classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..67e98b9d8ffb96a150b517497ace0a242d7163ef
--- /dev/null
+++ b/StableSR/ldm/models/diffusion/classifier.py
@@ -0,0 +1,267 @@
+import os
+import torch
+import pytorch_lightning as pl
+from omegaconf import OmegaConf
+from torch.nn import functional as F
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import LambdaLR
+from copy import deepcopy
+from einops import rearrange
+from glob import glob
+from natsort import natsorted
+
+from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
+from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
+
+__models__ = {
+ 'class_label': EncoderUNetModel,
+ 'segmentation': UNetModel
+}
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class NoisyLatentImageClassifier(pl.LightningModule):
+
+ def __init__(self,
+ diffusion_path,
+ num_classes,
+ ckpt_path=None,
+ pool='attention',
+ label_key=None,
+ diffusion_ckpt_path=None,
+ scheduler_config=None,
+ weight_decay=1.e-2,
+ log_steps=10,
+ monitor='val/loss',
+ *args,
+ **kwargs):
+ super().__init__(*args, **kwargs)
+ self.num_classes = num_classes
+ # get latest config of diffusion model
+ diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
+ self.diffusion_config = OmegaConf.load(diffusion_config).model
+ self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
+ self.load_diffusion()
+
+ self.monitor = monitor
+ self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
+ self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
+ self.log_steps = log_steps
+
+ self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
+ else self.diffusion_model.cond_stage_key
+
+ assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
+
+ if self.label_key not in __models__:
+ raise NotImplementedError()
+
+ self.load_classifier(ckpt_path, pool)
+
+ self.scheduler_config = scheduler_config
+ self.use_scheduler = self.scheduler_config is not None
+ self.weight_decay = weight_decay
+
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+ sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+
+ def load_diffusion(self):
+ model = instantiate_from_config(self.diffusion_config)
+ self.diffusion_model = model.eval()
+ self.diffusion_model.train = disabled_train
+ for param in self.diffusion_model.parameters():
+ param.requires_grad = False
+
+ def load_classifier(self, ckpt_path, pool):
+ model_config = deepcopy(self.diffusion_config.params.unet_config.params)
+ model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
+ model_config.out_channels = self.num_classes
+ if self.label_key == 'class_label':
+ model_config.pool = pool
+
+ self.model = __models__[self.label_key](**model_config)
+ if ckpt_path is not None:
+ print('#####################################################################')
+ print(f'load from ckpt "{ckpt_path}"')
+ print('#####################################################################')
+ self.init_from_ckpt(ckpt_path)
+
+ @torch.no_grad()
+ def get_x_noisy(self, x, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x))
+ continuous_sqrt_alpha_cumprod = None
+ if self.diffusion_model.use_continuous_noise:
+ continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
+ # todo: make sure t+1 is correct here
+
+ return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
+ continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
+
+ def forward(self, x_noisy, t, *args, **kwargs):
+ return self.model(x_noisy, t)
+
+ @torch.no_grad()
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = rearrange(x, 'b h w c -> b c h w')
+ x = x.to(memory_format=torch.contiguous_format).float()
+ return x
+
+ @torch.no_grad()
+ def get_conditioning(self, batch, k=None):
+ if k is None:
+ k = self.label_key
+ assert k is not None, 'Needs to provide label key'
+
+ targets = batch[k].to(self.device)
+
+ if self.label_key == 'segmentation':
+ targets = rearrange(targets, 'b h w c -> b c h w')
+ for down in range(self.numd):
+ h, w = targets.shape[-2:]
+ targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
+
+ # targets = rearrange(targets,'b c h w -> b h w c')
+
+ return targets
+
+ def compute_top_k(self, logits, labels, k, reduction="mean"):
+ _, top_ks = torch.topk(logits, k, dim=1)
+ if reduction == "mean":
+ return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
+ elif reduction == "none":
+ return (top_ks == labels[:, None]).float().sum(dim=-1)
+
+ def on_train_epoch_start(self):
+ # save some memory
+ self.diffusion_model.model.to('cpu')
+
+ @torch.no_grad()
+ def write_logs(self, loss, logits, targets):
+ log_prefix = 'train' if self.training else 'val'
+ log = {}
+ log[f"{log_prefix}/loss"] = loss.mean()
+ log[f"{log_prefix}/acc@1"] = self.compute_top_k(
+ logits, targets, k=1, reduction="mean"
+ )
+ log[f"{log_prefix}/acc@5"] = self.compute_top_k(
+ logits, targets, k=5, reduction="mean"
+ )
+
+ self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
+ self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
+ self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
+ lr = self.optimizers().param_groups[0]['lr']
+ self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
+
+ def shared_step(self, batch, t=None):
+ x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
+ targets = self.get_conditioning(batch)
+ if targets.dim() == 4:
+ targets = targets.argmax(dim=1)
+ if t is None:
+ t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
+ else:
+ t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
+ x_noisy = self.get_x_noisy(x, t)
+ logits = self(x_noisy, t)
+
+ loss = F.cross_entropy(logits, targets, reduction='none')
+
+ self.write_logs(loss.detach(), logits.detach(), targets.detach())
+
+ loss = loss.mean()
+ return loss, logits, x_noisy, targets
+
+ def training_step(self, batch, batch_idx):
+ loss, *_ = self.shared_step(batch)
+ return loss
+
+ def reset_noise_accs(self):
+ self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
+ range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
+
+ def on_validation_start(self):
+ self.reset_noise_accs()
+
+ @torch.no_grad()
+ def validation_step(self, batch, batch_idx):
+ loss, *_ = self.shared_step(batch)
+
+ for t in self.noisy_acc:
+ _, logits, _, targets = self.shared_step(batch, t)
+ self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
+ self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
+
+ return loss
+
+ def configure_optimizers(self):
+ optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
+
+ if self.use_scheduler:
+ scheduler = instantiate_from_config(self.scheduler_config)
+
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
+ 'interval': 'step',
+ 'frequency': 1
+ }]
+ return [optimizer], scheduler
+
+ return optimizer
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, *args, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.diffusion_model.first_stage_key)
+ log['inputs'] = x
+
+ y = self.get_conditioning(batch)
+
+ if self.label_key == 'class_label':
+ y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
+ log['labels'] = y
+
+ if ismap(y):
+ log['labels'] = self.diffusion_model.to_rgb(y)
+
+ for step in range(self.log_steps):
+ current_time = step * self.log_time_interval
+
+ _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
+
+ log[f'inputs@t{current_time}'] = x_noisy
+
+ pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
+ pred = rearrange(pred, 'b h w c -> b c h w')
+
+ log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
+
+ for key in log:
+ log[key] = log[key][:N]
+
+ return log
diff --git a/StableSR/ldm/models/diffusion/ddim.py b/StableSR/ldm/models/diffusion/ddim.py
new file mode 100644
index 0000000000000000000000000000000000000000..411257c9184e334aae4f2da9c0bfea452884893e
--- /dev/null
+++ b/StableSR/ldm/models/diffusion/ddim.py
@@ -0,0 +1,675 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+from functools import partial
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
+ extract_into_tensor
+
+from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
+
+def space_timesteps(num_timesteps, section_counts):
+ """
+ Create a list of timesteps to use from an original diffusion process,
+ given the number of timesteps we want to take from equally-sized portions
+ of the original process.
+
+ For example, if there are 300 timesteps and the section counts are [10,15,20]
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
+
+ If the stride is a string starting with "ddim", then the fixed striding
+ from the DDIM paper is used, and only one section is allowed.
+
+ :param num_timesteps: the number of diffusion steps in the original
+ process to divide up.
+ :param section_counts: either a list of numbers, or a string containing
+ comma-separated numbers, indicating the step count
+ per section. As a special case, use "ddimN" where N
+ is a number of steps to use the striding from the
+ DDIM paper.
+ :return: a set of diffusion steps from the original process to use.
+ """
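+ # Worked example (sketch): space_timesteps(1000, "ddim250") searches for a stride i with
+ # len(range(0, 1000, i)) == 250 and returns {0, 4, 8, ..., 996}.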
+ if isinstance(section_counts, str):
+ if section_counts.startswith("ddim"):
+ desired_count = int(section_counts[len("ddim"):])
+ for i in range(1, num_timesteps):
+ if len(range(0, num_timesteps, i)) == desired_count:
+ return set(range(0, num_timesteps, i))
+ raise ValueError(
+ f"cannot create exactly {desired_count} steps with an integer stride"
+ )
+ section_counts = [int(x) for x in section_counts.split(",")] #[250,]
+ size_per = num_timesteps // len(section_counts)
+ extra = num_timesteps % len(section_counts)
+ start_idx = 0
+ all_steps = []
+ for i, section_count in enumerate(section_counts):
+ size = size_per + (1 if i < extra else 0)
+ if size < section_count:
+ raise ValueError(
+ f"cannot divide section of {size} steps into {section_count}"
+ )
+ if section_count <= 1:
+ frac_stride = 1
+ else:
+ frac_stride = (size - 1) / (section_count - 1)
+ cur_idx = 0.0
+ taken_steps = []
+ for _ in range(section_count):
+ taken_steps.append(start_idx + round(cur_idx))
+ cur_idx += frac_stride
+ all_steps += taken_steps
+ start_idx += size
+ return set(all_steps)
+
+class DDIMSampler(object):
+ def __init__(self, model, schedule="linear", **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+ alphas_cumprod = self.model.alphas_cumprod
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer('betas', to_torch(self.model.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,verbose=verbose)
+
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+ @torch.no_grad()
+ def q_sample(self, x_start, t, noise=None, ddim_num_steps=200):
+ self.make_schedule(ddim_num_steps=ddim_num_steps)
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+ samples, intermediates = self.ddim_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def ddim_sampling(self, cond, shape,
+ x_T=None, ddim_use_original_steps=False,
+ callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, log_every_t=100,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None,):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1. - mask) * img
+
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning)
+ img, pred_x0 = outs
+ if callback: callback(i)
+ if img_callback: img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['x_inter'].append(img)
+ intermediates['pred_x0'].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None):
+ b, *_, device = *x.shape, x.device
+
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ e_t = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ c_in = torch.cat([unconditional_conditioning, c])
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
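+ # DDIM update rule: x_{t-1} = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t^2) * e_t + sigma_t * noise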
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ @torch.no_grad()
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+ # fast, but does not allow for exact reconstruction
+ # t serves as an index to gather the correct alphas
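+ # forward diffusion q(x_t | x_0): x_t = sqrt(alphas_cumprod_t) * x_0 + sqrt(1 - alphas_cumprod_t) * noise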
+ if use_original_steps:
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+ else:
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+ if noise is None:
+ noise = torch.randn_like(x0)
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
+
+ @torch.no_grad()
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+ use_original_steps=False):
+
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+ timesteps = timesteps[:t_start]
+
+ time_range = np.flip(timesteps)
+ total_steps = timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+ x_dec = x_latent
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning)
+ return x_dec
+
+
+ @torch.no_grad()
+ def p_sample_ddim_sr(self, x, c, struct_c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None):
+ b, *_, device = *x.shape, x.device
+
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ e_t = self.model.apply_model(x, t, c, struct_c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ c_in = torch.cat([unconditional_conditioning, c])
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, struct_c).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ @torch.no_grad()
+ def decode_sr(self, x_latent, cond, struct_cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+ use_original_steps=False):
+
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+ timesteps = timesteps[:t_start]
+
+ time_range = np.flip(timesteps)
+ total_steps = timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+ x_dec = x_latent
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+ x_dec, _ = self.p_sample_ddim_sr(x_dec, cond, struct_cond, ts, index=index, use_original_steps=use_original_steps,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning)
+ return x_dec
+
+ @torch.no_grad()
+ def sample_sr(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ struct_cond=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ _, C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+ samples, intermediates = self.ddim_sampling_sr(conditioning, struct_cond, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def ddim_sampling_sr(self, cond, struct_cond, shape,
+ x_T=None, ddim_use_original_steps=False,
+ callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, log_every_t=100,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None,):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1. - mask) * img
+
+ outs = self.p_sample_ddim_sr(img, cond, struct_cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning)
+ img, pred_x0 = outs
+ if callback: callback(i)
+ if img_callback: img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['x_inter'].append(img)
+ intermediates['pred_x0'].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_ddim_sr(self, x, c, struct_c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None):
+ b, *_, device = *x.shape, x.device
+
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ e_t = self.model.apply_model(x, t, c, struct_c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ c_in = torch.cat([unconditional_conditioning, c])
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, struct_c).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+
+ @torch.no_grad()
+ def sample_sr_t(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ struct_cond=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ _, C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+ samples, intermediates = self.ddim_sampling_sr_t(conditioning, struct_cond, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def ddim_sampling_sr_t(self, cond, struct_cond, shape,
+ x_T=None, ddim_use_original_steps=False,
+ callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, log_every_t=100,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None,):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+ # timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else sorted(set(space_timesteps(1000, [self.ddim_timesteps.shape[0]])))
+ timesteps = np.array(timesteps)
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1. - mask) * img
+
+ outs = self.p_sample_ddim_sr_t(img, cond, struct_cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning)
+ img, pred_x0 = outs
+ if callback: callback(i)
+ if img_callback: img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['x_inter'].append(img)
+ intermediates['pred_x0'].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_ddim_sr_t(self, x, c, struct_c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None):
+ b, *_, device = *x.shape, x.device
+
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ struct_c_t = self.model.structcond_stage_model(struct_c, t)
+ e_t = self.model.apply_model(x, t, c, struct_c_t)
+ else:
+ raise NotImplementedError()
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ c_in = torch.cat([unconditional_conditioning, c])
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, struct_c).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
diff --git a/StableSR/ldm/models/diffusion/ddpm.py b/StableSR/ldm/models/diffusion/ddpm.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a0c83d9904e447bfe058c22e39a292509f7020d
--- /dev/null
+++ b/StableSR/ldm/models/diffusion/ddpm.py
@@ -0,0 +1,3234 @@
+"""
+wild mixture of
+https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
+https://github.com/CompVis/taming-transformers
+-- merci
+"""
+
+import torch
+import torch.nn as nn
+import numpy as np
+import pytorch_lightning as pl
+from torch.optim.lr_scheduler import LambdaLR
+from einops import rearrange, repeat
+from contextlib import contextmanager
+from functools import partial
+from tqdm import tqdm
+from torchvision.utils import make_grid
+from pytorch_lightning.utilities.distributed import rank_zero_only
+
+from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
+from ldm.modules.ema import LitEma
+from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
+from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+from ldm.models.diffusion.ddim import DDIMSampler
+
+from basicsr.utils import DiffJPEG, USMSharp
+from basicsr.utils.img_process_util import filter2D
+from basicsr.data.transforms import paired_random_crop, triplet_random_crop
+from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt, random_add_speckle_noise_pt, random_add_saltpepper_noise_pt, bivariate_Gaussian
+import random
+import torch.nn.functional as F
+
+from ldm.modules.diffusionmodules.util import make_ddim_timesteps
+import copy
+import os
+import cv2
+import matplotlib.pyplot as plt
+from sklearn.decomposition import PCA
+
+__conditioning_keys__ = {'concat': 'c_concat',
+ 'crossattn': 'c_crossattn',
+ 'adm': 'y'}
+
+def torch2img(input):
+ input_ = input[0]
+ input_ = input_.permute(1,2,0)
+ input_ = input_.data.cpu().numpy()
+ input_ = (input_ + 1.0) / 2
+ cv2.imwrite('./test.png', input_[:,:,::-1]*255.0)
+
+def cal_pca_components(input, n_components=3):
+ pca = PCA(n_components=n_components)
+ c, h, w = input.size()
+ pca_data = input.permute(1,2,0)
+ pca_data = pca_data.reshape(h*w, c)
+ pca_data = pca.fit_transform(pca_data.data.cpu().numpy())
+ pca_data = pca_data.reshape((h, w, n_components))
+ return pca_data
+
+def visualize_fea(save_path, fea_img):
+ fig = plt.figure(figsize = (fea_img.shape[1]/10, fea_img.shape[0]/10)) # Your image (W)idth and (H)eight in inches
+ plt.subplots_adjust(left = 0, right = 1.0, top = 1.0, bottom = 0)
+ im = plt.imshow(fea_img, vmin=0.0, vmax=1.0, cmap='jet', aspect='auto') # Show the image
+ plt.savefig(save_path)
+ plt.clf()
+
+def calc_mean_std(feat, eps=1e-5):
+ """Calculate mean and std for adaptive_instance_normalization.
+ Args:
+ feat (Tensor): 4D tensor.
+ eps (float): A small value added to the variance to avoid
+ divide-by-zero. Default: 1e-5.
+ """
+ size = feat.size()
+ assert len(size) == 4, 'The input feature should be 4D tensor.'
+ b, c = size[:2]
+ feat_var = feat.view(b, c, -1).var(dim=2) + eps
+ feat_std = feat_var.sqrt().view(b, c, 1, 1)
+ feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
+ return feat_mean, feat_std
+
+def adaptive_instance_normalization(content_feat, style_feat):
+ """Adaptive instance normalization.
+ Adjust the reference features so that their color and illumination statistics
+ match those of the degraded features.
+ Args:
+ content_feat (Tensor): The reference features.
+ style_feat (Tensor): The degraded features.
+ """
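+ # Usage sketch (hypothetical tensors): adaptive_instance_normalization(ref_feat, degraded_feat)
+ # re-normalizes ref_feat to the per-channel mean/std of degraded_feat.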
+ size = content_feat.size()
+ style_mean, style_std = calc_mean_std(style_feat)
+ content_mean, content_std = calc_mean_std(content_feat)
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+def space_timesteps(num_timesteps, section_counts):
+ """
+ Create a list of timesteps to use from an original diffusion process,
+ given the number of timesteps we want to take from equally-sized portions
+ of the original process.
+
+ For example, if there are 300 timesteps and the section counts are [10,15,20]
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
+
+ If the stride is a string starting with "ddim", then the fixed striding
+ from the DDIM paper is used, and only one section is allowed.
+
+ :param num_timesteps: the number of diffusion steps in the original
+ process to divide up.
+ :param section_counts: either a list of numbers, or a string containing
+ comma-separated numbers, indicating the step count
+ per section. As a special case, use "ddimN" where N
+ is a number of steps to use the striding from the
+ DDIM paper.
+ :return: a set of diffusion steps from the original process to use.
+ """
+ if isinstance(section_counts, str):
+ if section_counts.startswith("ddim"):
+ desired_count = int(section_counts[len("ddim"):])
+ for i in range(1, num_timesteps):
+ if len(range(0, num_timesteps, i)) == desired_count:
+ return set(range(0, num_timesteps, i))
+ raise ValueError(
+ f"cannot create exactly {desired_count} steps with an integer stride"
+ )
+ section_counts = [int(x) for x in section_counts.split(",")] #[250,]
+ size_per = num_timesteps // len(section_counts)
+ extra = num_timesteps % len(section_counts)
+ start_idx = 0
+ all_steps = []
+ for i, section_count in enumerate(section_counts):
+ size = size_per + (1 if i < extra else 0)
+ if size < section_count:
+ raise ValueError(
+ f"cannot divide section of {size} steps into {section_count}"
+ )
+ if section_count <= 1:
+ frac_stride = 1
+ else:
+ frac_stride = (size - 1) / (section_count - 1)
+ cur_idx = 0.0
+ taken_steps = []
+ for _ in range(section_count):
+ taken_steps.append(start_idx + round(cur_idx))
+ cur_idx += frac_stride
+ all_steps += taken_steps
+ start_idx += size
+ return set(all_steps)
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+def uniform_on_device(r1, r2, shape, device):
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
+
+class DDPM(pl.LightningModule):
+ # classic DDPM with Gaussian diffusion, in image space
+ def __init__(self,
+ unet_config,
+ timesteps=1000,
+ beta_schedule="linear",
+ loss_type="l2",
+ ckpt_path=None,
+ ignore_keys=[],
+ load_only_unet=False,
+ monitor="val/loss",
+ use_ema=True,
+ first_stage_key="image",
+ image_size=256,
+ channels=3,
+ log_every_t=100,
+ clip_denoised=True,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ given_betas=None,
+ original_elbo_weight=0.,
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+ l_simple_weight=1.,
+ conditioning_key=None,
+ parameterization="eps", # all assuming fixed variance schedules
+ scheduler_config=None,
+ use_positional_encodings=False,
+ learn_logvar=False,
+ logvar_init=0.,
+ ):
+ super().__init__()
+ assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
+ self.parameterization = parameterization
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
+ self.cond_stage_model = None
+ self.clip_denoised = clip_denoised
+ self.log_every_t = log_every_t
+ self.first_stage_key = first_stage_key
+ self.image_size = image_size # try conv?
+ self.channels = channels
+ self.use_positional_encodings = use_positional_encodings
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
+ count_params(self.model, verbose=True)
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self.model)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ self.use_scheduler = scheduler_config is not None
+ if self.use_scheduler:
+ self.scheduler_config = scheduler_config
+
+ self.v_posterior = v_posterior
+ self.original_elbo_weight = original_elbo_weight
+ self.l_simple_weight = l_simple_weight
+
+ if monitor is not None:
+ self.monitor = monitor
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
+
+ self.loss_type = loss_type
+
+ self.learn_logvar = learn_logvar
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
+ if self.learn_logvar:
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+
+
+ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ if exists(given_betas):
+ betas = given_betas
+ else:
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+ cosine_s=cosine_s)
+ alphas = 1. - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+ timesteps, = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer('betas', to_torch(betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
+ 1. - alphas_cumprod) + self.v_posterior * betas
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
+ self.register_buffer('posterior_mean_coef1', to_torch(
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
+ self.register_buffer('posterior_mean_coef2', to_torch(
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
+
+ if self.parameterization == "eps":
+ lvlb_weights = self.betas ** 2 / (
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
+ elif self.parameterization == "x0":
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
+ elif self.parameterization == "v":
+ lvlb_weights = torch.ones_like(self.betas ** 2 / (
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
+ else:
+ raise NotImplementedError("mu not supported")
+ # TODO how to choose this term
+ lvlb_weights[0] = lvlb_weights[1]
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
+ assert not torch.isnan(self.lvlb_weights).all()
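+
+    # For reference: the buffers registered above implement the standard DDPM closed forms
+    # q(x_t | x_0) = N(sqrt(alphas_cumprod_t) * x_0, (1 - alphas_cumprod_t) * I) and the posterior
+    # mean q(x_{t-1} | x_t, x_0) = posterior_mean_coef1 * x_0 + posterior_mean_coef2 * x_t.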
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.model.parameters())
+ self.model_ema.copy_to(self.model)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.model.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+ sd, strict=False)
+ print('<<<<<<<<<<<<>>>>>>>>>>>>>>>')
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+ return mean, variance, log_variance
+
+ def predict_start_from_noise(self, x_t, t, noise):
+ return (
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+ )
+
+ def q_posterior(self, x_start, x_t, t):
+ posterior_mean = (
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self, x, t, clip_denoised: bool):
+ model_out = self.model(x, t)
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ elif self.parameterization == "v":
+ x_recon = self.predict_start_from_z_and_v(x, model_out, t)
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
+ b, *_, device = *x.shape, x.device
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
+ noise = noise_like(x.shape, device, repeat_noise)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
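+
+    # For reference: the return value is the ancestral sampling update
+    # x_{t-1} = mu_theta(x_t, t) + sigma_t * z with z ~ N(0, I); nonzero_mask removes the noise
+    # term at t == 0 so the final denoising step is deterministic.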
+
+ @torch.no_grad()
+ def p_sample_loop(self, shape, return_intermediates=False):
+ device = self.betas.device
+ b = shape[0]
+ img = torch.randn(shape, device=device)
+ intermediates = [img]
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
+ img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
+ clip_denoised=self.clip_denoised)
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
+ intermediates.append(img)
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, batch_size=16, return_intermediates=False):
+ image_size = self.image_size
+ channels = self.channels
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
+ return_intermediates=return_intermediates)
+
+ def q_sample(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
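+
+    # For reference: this is the closed-form forward process
+    # x_t = sqrt(alphas_cumprod_t) * x_0 + sqrt(1 - alphas_cumprod_t) * eps with eps ~ N(0, I),
+    # so training examples can be noised to an arbitrary timestep in a single call.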
+
+ def q_sample_respace(self, x_start, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (extract_into_tensor(sqrt_alphas_cumprod.to(noise.device), t, x_start.shape) * x_start +
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod.to(noise.device), t, x_start.shape) * noise)
+
+ def get_v(self, x, noise, t):
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
+ )
+
+ def predict_start_from_z_and_v(self, x, v, t):
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * x -
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * v
+ )
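+
+    # For reference: get_v and predict_start_from_z_and_v implement the v-parameterization
+    # v = sqrt(alphas_cumprod_t) * eps - sqrt(1 - alphas_cumprod_t) * x_0 and its inverse
+    # x_0 = sqrt(alphas_cumprod_t) * x_t - sqrt(1 - alphas_cumprod_t) * v.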
+
+ def get_loss(self, pred, target, mean=True):
+ if self.loss_type == 'l1':
+ loss = (target - pred).abs()
+ if mean:
+ loss = loss.mean()
+ elif self.loss_type == 'l2':
+ if mean:
+ loss = torch.nn.functional.mse_loss(target, pred)
+ else:
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
+ else:
+ raise NotImplementedError("unknown loss type '{loss_type}'")
+
+ return loss
+
+ def p_losses(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_out = self.model(x_noisy, t)
+
+ loss_dict = {}
+ if self.parameterization == "eps":
+ target = noise
+ elif self.parameterization == "x0":
+ target = x_start
+ elif self.parameterization == "v":
+ target = self.get_v(x_start, noise, t)
+ else:
+ raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
+
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
+
+ log_prefix = 'train' if self.training else 'val'
+
+ loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
+ loss_simple = loss.mean() * self.l_simple_weight
+
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
+ loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
+
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
+
+ loss_dict.update({f'{log_prefix}/loss': loss})
+
+ return loss, loss_dict
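+
+    # For reference: the returned loss is
+    # l_simple_weight * mean(loss) + original_elbo_weight * mean(lvlb_weights[t] * loss),
+    # i.e. the simple objective plus an ELBO-weighted term that vanishes with the default
+    # original_elbo_weight = 0.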
+
+ def forward(self, x, *args, **kwargs):
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+ return self.p_losses(x, t, *args, **kwargs)
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = rearrange(x, 'b h w c -> b c h w')
+ x = x.to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def shared_step(self, batch):
+ x = self.get_input(batch, self.first_stage_key)
+ loss, loss_dict = self(x)
+ return loss, loss_dict
+
+ def training_step(self, batch, batch_idx):
+ loss, loss_dict = self.shared_step(batch)
+
+ self.log_dict(loss_dict, prog_bar=True,
+ logger=True, on_step=True, on_epoch=True)
+
+ self.log("global_step", self.global_step,
+ prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+ if self.use_scheduler:
+ lr = self.optimizers().param_groups[0]['lr']
+ self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+ return loss
+
+ @torch.no_grad()
+ def validation_step(self, batch, batch_idx):
+ _, loss_dict_no_ema = self.shared_step(batch)
+ with self.ema_scope():
+ _, loss_dict_ema = self.shared_step(batch)
+ loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self.model)
+
+ def _get_rows_from_list(self, samples):
+ n_imgs_per_row = len(samples)
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.first_stage_key)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ x = x.to(self.device)[:N]
+ log["inputs"] = x
+
+ # get diffusion row
+ diffusion_row = list()
+ x_start = x[:n_row]
+
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(x_start)
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ diffusion_row.append(x_noisy)
+
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
+
+ if sample:
+ # get denoise row
+ with self.ema_scope("Plotting"):
+ samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
+
+ log["samples"] = samples
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ if self.learn_logvar:
+ params = params + [self.logvar]
+ opt = torch.optim.AdamW(params, lr=lr)
+ return opt
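+
+    # Note that self.learning_rate is not set in __init__; it is expected to be assigned
+    # externally (e.g. by the training script) before configure_optimizers is called.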
+
+class LatentDiffusion(DDPM):
+ """main class"""
+ def __init__(self,
+ first_stage_config,
+ cond_stage_config,
+ num_timesteps_cond=None,
+ cond_stage_key="image",
+ cond_stage_trainable=False,
+ concat_mode=True,
+ cond_stage_forward=None,
+ conditioning_key=None,
+ scale_factor=1.0,
+ scale_by_std=False,
+ *args, **kwargs):
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
+ self.scale_by_std = scale_by_std
+ assert self.num_timesteps_cond <= kwargs['timesteps']
+ # for backwards compatibility after implementation of DiffusionWrapper
+ if conditioning_key is None:
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
+ if cond_stage_config == '__is_unconditional__':
+ conditioning_key = None
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ignore_keys = kwargs.pop("ignore_keys", [])
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+ self.concat_mode = concat_mode
+ self.cond_stage_trainable = cond_stage_trainable
+ self.cond_stage_key = cond_stage_key
+ try:
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
+        except Exception:
+ self.num_downs = 0
+ if not scale_by_std:
+ self.scale_factor = scale_factor
+ else:
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
+ self.instantiate_first_stage(first_stage_config)
+ self.instantiate_cond_stage(cond_stage_config)
+ self.cond_stage_forward = cond_stage_forward
+ self.clip_denoised = False
+ self.bbox_tokenizer = None
+
+ self.restarted_from_ckpt = False
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys)
+ self.restarted_from_ckpt = True
+
+ # self.model.eval()
+ # self.model.train = disabled_train
+ # for param in self.model.parameters():
+ # param.requires_grad = False
+
+ def make_cond_schedule(self, ):
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
+ self.cond_ids[:self.num_timesteps_cond] = ids
+
+ @rank_zero_only
+ @torch.no_grad()
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+ # only for very first batch
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
+ # set rescale weight to 1./std of encodings
+ print("### USING STD-RESCALING ###")
+ x = super().get_input(batch, self.first_stage_key)
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+ del self.scale_factor
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
+ print(f"setting self.scale_factor to {self.scale_factor}")
+ print("### USING STD-RESCALING ###")
+
+ def register_schedule(self,
+ given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
+
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
+ if self.shorten_cond_schedule:
+ self.make_cond_schedule()
+
+ def instantiate_first_stage(self, config):
+ model = instantiate_from_config(config)
+ self.first_stage_model = model.eval()
+ self.first_stage_model.train = disabled_train
+ for param in self.first_stage_model.parameters():
+ param.requires_grad = False
+
+ def instantiate_cond_stage(self, config):
+ if not self.cond_stage_trainable:
+ if config == "__is_first_stage__":
+ print("Using first stage also as cond stage.")
+ self.cond_stage_model = self.first_stage_model
+ elif config == "__is_unconditional__":
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
+ self.cond_stage_model = None
+ # self.be_unconditional = True
+ else:
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model.eval()
+ self.cond_stage_model.train = disabled_train
+ for param in self.cond_stage_model.parameters():
+ param.requires_grad = False
+ else:
+ assert config != '__is_first_stage__'
+ assert config != '__is_unconditional__'
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model
+
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
+ denoise_row = []
+ for zd in tqdm(samples, desc=desc):
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
+ force_not_quantize=force_no_decoder_quantization))
+ n_imgs_per_row = len(denoise_row)
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ def get_first_stage_encoding(self, encoder_posterior):
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+ z = encoder_posterior.sample()
+ elif isinstance(encoder_posterior, torch.Tensor):
+ z = encoder_posterior
+ else:
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
+ return self.scale_factor * z
+
+ def get_learned_conditioning(self, c):
+ if self.cond_stage_forward is None:
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
+ c = self.cond_stage_model.encode(c)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ else:
+ c = self.cond_stage_model(c)
+ else:
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+ return c
+
+ def meshgrid(self, h, w):
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
+
+ arr = torch.cat([y, x], dim=-1)
+ return arr
+
+ def delta_border(self, h, w):
+ """
+ :param h: height
+ :param w: width
+        :return: normalized distance to image border,
+         with min distance = 0 at the border and max distance = 0.5 at the image center
+ """
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
+ arr = self.meshgrid(h, w) / lower_right_corner
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
+ return edge_dist
+
+ def get_weighting(self, h, w, Ly, Lx, device):
+ weighting = self.delta_border(h, w)
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
+ self.split_input_params["clip_max_weight"], )
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
+
+ if self.split_input_params["tie_braker"]:
+ L_weighting = self.delta_border(Ly, Lx)
+ L_weighting = torch.clip(L_weighting,
+ self.split_input_params["clip_min_tie_weight"],
+ self.split_input_params["clip_max_tie_weight"])
+
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
+ weighting = weighting * L_weighting
+ return weighting
+
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
+ """
+ :param x: img of size (bs, c, h, w)
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
+ """
+ bs, nc, h, w = x.shape
+
+ # number of crops in image
+ Ly = (h - kernel_size[0]) // stride[0] + 1
+ Lx = (w - kernel_size[1]) // stride[1] + 1
+
+ if uf == 1 and df == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
+
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
+
+ elif uf > 1 and df == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
+ dilation=1, padding=0,
+ stride=(stride[0] * uf, stride[1] * uf))
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
+
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
+
+ elif df > 1 and uf == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
+ dilation=1, padding=0,
+ stride=(stride[0] // df, stride[1] // df))
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
+
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
+
+ else:
+ raise NotImplementedError
+
+ return fold, unfold, normalization, weighting
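+
+    # Worked example (uf == df == 1): for a 512x512 input with kernel_size=(128, 128) and
+    # stride=(64, 64), Ly = Lx = (512 - 128) // 64 + 1 = 7, i.e. 49 overlapping crops;
+    # fold(weighting) yields the per-pixel normalization that compensates for the overlaps
+    # when the crops are stitched back together.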
+
+ @torch.no_grad()
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
+ cond_key=None, return_original_cond=False, bs=None):
+ x = batch[k]
+
+ x = F.interpolate(
+ x,
+ size=(self.image_size,
+ self.image_size),
+ mode='bicubic',
+ )
+
+ if len(x.shape) == 3:
+ x = x[..., None]
+ # x = rearrange(x, 'b h w c -> b c h w')
+ x = x.to(memory_format=torch.contiguous_format).float()
+
+ if bs is not None:
+ x = x[:bs]
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+
+ if self.model.conditioning_key is not None:
+ if cond_key is None:
+ cond_key = self.cond_stage_key
+ if cond_key != self.first_stage_key:
+ if cond_key in ['caption', 'coordinates_bbox']:
+ # xc = batch[cond_key]
+ xc = ['']*x.size(0)
+ elif cond_key == 'class_label':
+ xc = batch
+ else:
+ xc = super().get_input(batch, cond_key).to(self.device)
+ else:
+ xc = x
+ if not self.cond_stage_trainable or force_c_encode:
+ if isinstance(xc, dict) or isinstance(xc, list):
+ # import pudb; pudb.set_trace()
+ c = self.get_learned_conditioning(xc)
+ else:
+ c = self.get_learned_conditioning(xc.to(self.device))
+ else:
+ c = xc
+
+ if bs is not None:
+ c = c[:bs]
+
+ if self.use_positional_encodings:
+ pos_x, pos_y = self.compute_latent_shifts(batch)
+ ckey = __conditioning_keys__[self.model.conditioning_key]
+ c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
+
+ else:
+ c = None
+ xc = None
+ if self.use_positional_encodings:
+ pos_x, pos_y = self.compute_latent_shifts(batch)
+ c = {'pos_x': pos_x, 'pos_y': pos_y}
+ out = [z, c]
+ if return_first_stage_outputs:
+ xrec = self.decode_first_stage(z)
+ out.extend([x, xrec])
+ if return_original_cond:
+ out.append(xc)
+ return out
+
+ @torch.no_grad()
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+ if predict_cids:
+ if z.dim() == 4:
+ z = torch.argmax(z.exp(), dim=1).long()
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+ z = 1. / self.scale_factor * z
+
+ if hasattr(self, "split_input_params"):
+ if self.split_input_params["patch_distributed_vq"]:
+ ks = self.split_input_params["ks"] # eg. (128, 128)
+ stride = self.split_input_params["stride"] # eg. (64, 64)
+ uf = self.split_input_params["vqf"]
+ bs, nc, h, w = z.shape
+ if ks[0] > h or ks[1] > w:
+ ks = (min(ks[0], h), min(ks[1], w))
+ print("reducing Kernel")
+
+ if stride[0] > h or stride[1] > w:
+ stride = (min(stride[0], h), min(stride[1], w))
+ print("reducing stride")
+
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
+
+ z = unfold(z) # (bn, nc * prod(**ks), L)
+ # 1. Reshape to img shape
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ # 2. apply model loop over last dim
+ if isinstance(self.first_stage_model, VQModelInterface):
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
+ force_not_quantize=predict_cids or force_not_quantize)
+ for i in range(z.shape[-1])]
+ else:
+
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
+ for i in range(z.shape[-1])]
+
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
+ o = o * weighting
+ # Reverse 1. reshape to img shape
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # stitch crops together
+ decoded = fold(o)
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
+ return decoded
+ else:
+ if isinstance(self.first_stage_model, VQModelInterface):
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ else:
+ return self.first_stage_model.decode(z)
+
+ else:
+ if isinstance(self.first_stage_model, VQModelInterface):
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ else:
+ return self.first_stage_model.decode(z)
+
+ # same as above but without decorator
+ def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+ if predict_cids:
+ if z.dim() == 4:
+ z = torch.argmax(z.exp(), dim=1).long()
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+ z = 1. / self.scale_factor * z
+
+ if hasattr(self, "split_input_params"):
+ if self.split_input_params["patch_distributed_vq"]:
+ ks = self.split_input_params["ks"] # eg. (128, 128)
+ stride = self.split_input_params["stride"] # eg. (64, 64)
+ uf = self.split_input_params["vqf"]
+ bs, nc, h, w = z.shape
+ if ks[0] > h or ks[1] > w:
+ ks = (min(ks[0], h), min(ks[1], w))
+ print("reducing Kernel")
+
+ if stride[0] > h or stride[1] > w:
+ stride = (min(stride[0], h), min(stride[1], w))
+ print("reducing stride")
+
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
+
+ z = unfold(z) # (bn, nc * prod(**ks), L)
+ # 1. Reshape to img shape
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ # 2. apply model loop over last dim
+ if isinstance(self.first_stage_model, VQModelInterface):
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
+ force_not_quantize=predict_cids or force_not_quantize)
+ for i in range(z.shape[-1])]
+ else:
+
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
+ for i in range(z.shape[-1])]
+
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
+ o = o * weighting
+ # Reverse 1. reshape to img shape
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # stitch crops together
+ decoded = fold(o)
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
+ return decoded
+ else:
+ if isinstance(self.first_stage_model, VQModelInterface):
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ else:
+ return self.first_stage_model.decode(z)
+
+ else:
+ if isinstance(self.first_stage_model, VQModelInterface):
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ else:
+ return self.first_stage_model.decode(z)
+
+ @torch.no_grad()
+ def encode_first_stage(self, x):
+ if hasattr(self, "split_input_params"):
+ if self.split_input_params["patch_distributed_vq"]:
+ ks = self.split_input_params["ks"] # eg. (128, 128)
+ stride = self.split_input_params["stride"] # eg. (64, 64)
+ df = self.split_input_params["vqf"]
+ self.split_input_params['original_image_size'] = x.shape[-2:]
+ bs, nc, h, w = x.shape
+ if ks[0] > h or ks[1] > w:
+ ks = (min(ks[0], h), min(ks[1], w))
+ print("reducing Kernel")
+
+ if stride[0] > h or stride[1] > w:
+ stride = (min(stride[0], h), min(stride[1], w))
+ print("reducing stride")
+
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
+ z = unfold(x) # (bn, nc * prod(**ks), L)
+ # Reshape to img shape
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
+ for i in range(z.shape[-1])]
+
+ o = torch.stack(output_list, axis=-1)
+ o = o * weighting
+
+ # Reverse reshape to img shape
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # stitch crops together
+ decoded = fold(o)
+ decoded = decoded / normalization
+ return decoded
+
+ else:
+ return self.first_stage_model.encode(x)
+ else:
+ return self.first_stage_model.encode(x)
+
+ def shared_step(self, batch, **kwargs):
+ x, c = self.get_input(batch, self.first_stage_key)
+ loss = self(x, c)
+ return loss
+
+ def forward(self, x, c, *args, **kwargs):
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+ if self.model.conditioning_key is not None:
+ assert c is not None
+ if self.cond_stage_trainable:
+ c = self.get_learned_conditioning(c)
+ if self.shorten_cond_schedule: # TODO: drop this option
+ tc = self.cond_ids[t].to(self.device)
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
+ return self.p_losses(x, c, t, *args, **kwargs)
+
+ def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
+ def rescale_bbox(bbox):
+ x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
+ y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
+ w = min(bbox[2] / crop_coordinates[2], 1 - x0)
+ h = min(bbox[3] / crop_coordinates[3], 1 - y0)
+ return x0, y0, w, h
+
+ return [rescale_bbox(b) for b in bboxes]
+
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
+
+ if isinstance(cond, dict):
+            # hybrid case, cond is expected to be a dict
+ pass
+ else:
+ if not isinstance(cond, list):
+ cond = [cond]
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
+ cond = {key: cond}
+
+ if hasattr(self, "split_input_params"):
+ assert len(cond) == 1 # todo can only deal with one conditioning atm
+ assert not return_ids
+ ks = self.split_input_params["ks"] # eg. (128, 128)
+ stride = self.split_input_params["stride"] # eg. (64, 64)
+
+ h, w = x_noisy.shape[-2:]
+
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
+
+ z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
+ # Reshape to img shape
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+ z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
+
+ if self.cond_stage_key in ["image", "LR_image", "segmentation",
+ 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
+ c_key = next(iter(cond.keys())) # get key
+ c = next(iter(cond.values())) # get value
+ assert (len(c) == 1) # todo extend to list with more than one elem
+ c = c[0] # get element
+
+ c = unfold(c)
+ c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
+
+ elif self.cond_stage_key == 'coordinates_bbox':
+                assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'
+
+ # assuming padding of unfold is always 0 and its dilation is always 1
+ n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
+ full_img_h, full_img_w = self.split_input_params['original_image_size']
+ # as we are operating on latents, we need the factor from the original image size to the
+ # spatial latent size to properly rescale the crops for regenerating the bbox annotations
+ num_downs = self.first_stage_model.encoder.num_resolutions - 1
+ rescale_latent = 2 ** (num_downs)
+
+                # get top-left positions of the patches as expected by the bbox tokenizer; therefore we
+                # need to rescale the top-left patch coordinates to lie in (0, 1)
+ tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
+ rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
+ for patch_nr in range(z.shape[-1])]
+
+ # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
+ patch_limits = [(x_tl, y_tl,
+ rescale_latent * ks[0] / full_img_w,
+ rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
+ # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
+
+ # tokenize crop coordinates for the bounding boxes of the respective patches
+ patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
+ for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
+ print(patch_limits_tknzd[0].shape)
+ # cut tknzd crop position from conditioning
+ assert isinstance(cond, dict), 'cond must be dict to be fed into model'
+ cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
+ print(cut_cond.shape)
+
+ adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
+ adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
+ print(adapted_cond.shape)
+ adapted_cond = self.get_learned_conditioning(adapted_cond)
+ print(adapted_cond.shape)
+ adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
+ print(adapted_cond.shape)
+
+ cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
+
+ else:
+ cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
+
+ # apply model by loop over crops
+ output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
+ assert not isinstance(output_list[0],
+ tuple) # todo cant deal with multiple model outputs check this never happens
+
+ o = torch.stack(output_list, axis=-1)
+ o = o * weighting
+ # Reverse reshape to img shape
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # stitch crops together
+ x_recon = fold(o) / normalization
+
+ else:
+ x_recon = self.model(x_noisy, t, **cond)
+
+ if isinstance(x_recon, tuple) and not return_ids:
+ return x_recon[0]
+ else:
+ return x_recon
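+
+    # For reference: with split_input_params set, the noisy latent is unfolded into overlapping
+    # crops, the model is applied per crop (with per-crop conditioning where applicable), and the
+    # outputs are re-weighted and folded back together; otherwise the model sees the whole latent.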
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+ This term can't be optimized, as it only depends on the encoder.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
+ return mean_flat(kl_prior) / np.log(2.0)
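+
+    # For reference: normal_kl returns nats; dividing by log(2) converts to bits, and mean_flat
+    # (mean over the non-batch dimensions) gives bits per dimension.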
+
+ def p_losses(self, x_start, cond, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_output = self.apply_model(x_noisy, t, cond)
+
+ loss_dict = {}
+ prefix = 'train' if self.training else 'val'
+
+ if self.parameterization == "x0":
+ target = x_start
+ elif self.parameterization == "eps":
+ target = noise
+ elif self.parameterization == "v":
+ target = self.get_v(x_start, noise, t)
+ else:
+ raise NotImplementedError()
+
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
+
+ logvar_t = self.logvar[t].to(self.device)
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
+ if self.learn_logvar:
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
+ loss_dict.update({'logvar': self.logvar.data.mean()})
+
+ loss = self.l_simple_weight * loss.mean()
+
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
+ loss += (self.original_elbo_weight * loss_vlb)
+ loss_dict.update({f'{prefix}/loss': loss})
+
+ return loss, loss_dict
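+
+    # For reference: loss_simple is re-weighted per timestep as loss_simple / exp(logvar_t) + logvar_t;
+    # with the default logvar_init = 0 and learn_logvar = False this reduces to plain loss_simple.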
+
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
+ t_in = t
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
+
+ if score_corrector is not None:
+ assert self.parameterization == "eps"
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
+
+ if return_codebook_ids:
+ model_out, logits = model_out
+
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ elif self.parameterization == "v":
+ x_recon = self.predict_start_from_z_and_v(x, model_out, t)
+ else:
+ raise NotImplementedError()
+
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+ if quantize_denoised:
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ if return_codebook_ids:
+ return model_mean, posterior_variance, posterior_log_variance, logits
+ elif return_x0:
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
+ else:
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
+ b, *_, device = *x.shape, x.device
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
+ return_codebook_ids=return_codebook_ids,
+ quantize_denoised=quantize_denoised,
+ return_x0=return_x0,
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+ if return_codebook_ids:
+ raise DeprecationWarning("Support dropped.")
+ model_mean, _, model_log_variance, logits = outputs
+ elif return_x0:
+ model_mean, _, model_log_variance, x0 = outputs
+ else:
+ model_mean, _, model_log_variance = outputs
+
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+
+ if return_codebook_ids:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
+ if return_x0:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
+ else:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
+ log_every_t=None):
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ timesteps = self.num_timesteps
+ if batch_size is not None:
+ b = batch_size if batch_size is not None else shape[0]
+ shape = [batch_size] + list(shape)
+ else:
+ b = batch_size = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=self.device)
+ else:
+ img = x_T
+ intermediates = []
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ else:
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
+ total=timesteps) if verbose else reversed(
+ range(0, timesteps))
+ if type(temperature) == float:
+ temperature = [temperature] * timesteps
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img, x0_partial = self.p_sample(img, cond, ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised, return_x0=True,
+ temperature=temperature[i], noise_dropout=noise_dropout,
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(x0_partial)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, start_T=None,
+ log_every_t=None):
+
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ device = self.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ intermediates = [img]
+ if timesteps is None:
+ timesteps = self.num_timesteps
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
+ range(0, timesteps))
+
+ if mask is not None:
+ assert x0 is not None
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img = self.p_sample(img, cond, ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised)
+ if mask is not None:
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(img)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
+ verbose=True, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, shape=None,**kwargs):
+ if shape is None:
+ shape = (batch_size, self.channels, self.image_size//8, self.image_size//8)
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ else:
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+ return self.p_sample_loop(cond,
+ shape,
+ return_intermediates=return_intermediates, x_T=x_T,
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
+ mask=mask, x0=x0)
+
+ @torch.no_grad()
+ def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
+
+ if ddim:
+ ddim_sampler = DDIMSampler(self)
+ shape = (self.channels, self.image_size, self.image_size)
+ samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
+ shape,cond,verbose=False,**kwargs)
+
+ else:
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
+ return_intermediates=True,**kwargs)
+
+ return samples, intermediates
+
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+ quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=True,
+ plot_diffusion_rows=True, **kwargs):
+
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
+ return_first_stage_outputs=True,
+ force_c_encode=True,
+ return_original_cond=True,
+ bs=N)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+
+ # print(z.size())
+ # print(x.size())
+ # if self.model.conditioning_key is not None:
+ # if hasattr(self.cond_stage_model, "decode"):
+ # xc = self.cond_stage_model.decode(c)
+ # log["conditioning"] = xc
+ # elif self.cond_stage_key in ["caption"]:
+ # xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
+ # log["conditioning"] = xc
+ # elif self.cond_stage_key == 'class_label':
+ # xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
+ # log['conditioning'] = xc
+ # elif isimage(xc):
+ # log["conditioning"] = xc
+ # if ismap(xc):
+ # log["original_conditioning"] = self.to_rgb(xc)
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ with self.ema_scope("Plotting"):
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
+ ddim_steps=ddim_steps,eta=ddim_eta)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
+ self.first_stage_model, IdentityFirstStage):
+ # also display when quantizing x0 while sampling
+ with self.ema_scope("Plotting Quantized Denoised"):
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
+ ddim_steps=ddim_steps,eta=ddim_eta,
+ quantize_denoised=True)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
+ # quantize_denoised=True)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_x0_quantized"] = x_samples
+
+ if inpaint:
+ # make a simple center square
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
+ mask = torch.ones(N, h, w).to(self.device)
+ # zeros will be filled in
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
+ mask = mask[:, None, ...]
+ with self.ema_scope("Plotting Inpaint"):
+
+ samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_inpainting"] = x_samples
+ log["mask"] = mask
+
+ # outpaint
+ with self.ema_scope("Plotting Outpaint"):
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_outpainting"] = x_samples
+
+ if plot_progressive_rows:
+ with self.ema_scope("Plotting Progressives"):
+ img, progressives = self.progressive_denoising(c,
+ shape=(self.channels, self.image_size, self.image_size),
+ batch_size=N)
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+ log["progressive_row"] = prog_row
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ if self.cond_stage_trainable:
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
+ params = params + list(self.cond_stage_model.parameters())
+ # params = list(self.cond_stage_model.parameters())
+ if self.learn_logvar:
+ print('Diffusion model optimizing logvar')
+ params.append(self.logvar)
+ opt = torch.optim.AdamW(params, lr=lr)
+ if self.use_scheduler:
+ assert 'target' in self.scheduler_config
+ scheduler = instantiate_from_config(self.scheduler_config)
+
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
+ 'interval': 'step',
+ 'frequency': 1
+ }]
+ return [opt], scheduler
+ return opt
+
+ @torch.no_grad()
+ def to_rgb(self, x):
+ x = x.float()
+ if not hasattr(self, "colorize"):
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = nn.functional.conv2d(x, weight=self.colorize)
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
+ return x
+
+class LatentDiffusionSRTextWT(DDPM):
+ """main class"""
+ def __init__(self,
+ first_stage_config,
+ cond_stage_config,
+ structcond_stage_config,
+ num_timesteps_cond=None,
+ cond_stage_key="image",
+ cond_stage_trainable=False,
+ concat_mode=True,
+ cond_stage_forward=None,
+ conditioning_key=None,
+ scale_factor=1.0,
+ scale_by_std=False,
+ unfrozen_diff=False,
+ random_size=False,
+ test_gt=False,
+ p2_gamma=None,
+ p2_k=None,
+ time_replace=None,
+ use_usm=False,
+ mix_ratio=0.0,
+ *args, **kwargs):
+ # put this in your init
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
+ self.scale_by_std = scale_by_std
+ self.unfrozen_diff = unfrozen_diff
+ self.random_size = random_size
+ self.test_gt = test_gt
+ self.time_replace = time_replace
+ self.use_usm = use_usm
+ self.mix_ratio = mix_ratio
+ assert self.num_timesteps_cond <= kwargs['timesteps']
+ # for backwards compatibility after implementation of DiffusionWrapper
+ if conditioning_key is None:
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
+ if cond_stage_config == '__is_unconditional__':
+ conditioning_key = None
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ignore_keys = kwargs.pop("ignore_keys", [])
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+ self.concat_mode = concat_mode
+ self.cond_stage_trainable = cond_stage_trainable
+ self.cond_stage_key = cond_stage_key
+ try:
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
+        except Exception:
+ self.num_downs = 0
+ if not scale_by_std:
+ self.scale_factor = scale_factor
+ else:
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
+ self.instantiate_first_stage(first_stage_config)
+ self.instantiate_cond_stage(cond_stage_config)
+ self.instantiate_structcond_stage(structcond_stage_config)
+ self.cond_stage_forward = cond_stage_forward
+ self.clip_denoised = False
+ self.bbox_tokenizer = None
+
+ self.restarted_from_ckpt = False
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys)
+ self.restarted_from_ckpt = True
+
+ if not self.unfrozen_diff:
+ self.model.eval()
+ # self.model.train = disabled_train
+ for name, param in self.model.named_parameters():
+ if 'spade' not in name:
+ param.requires_grad = False
+ else:
+ param.requires_grad = True
+
+ print('>>>>>>>>>>>>>>>>model>>>>>>>>>>>>>>>>>>>>')
+ param_list = []
+ for name, params in self.model.named_parameters():
+ if params.requires_grad:
+ param_list.append(name)
+ print(param_list)
+ param_list = []
+ print('>>>>>>>>>>>>>>>>>cond_stage_model>>>>>>>>>>>>>>>>>>>')
+ for name, params in self.cond_stage_model.named_parameters():
+ if params.requires_grad:
+ param_list.append(name)
+ print(param_list)
+ param_list = []
+ print('>>>>>>>>>>>>>>>>structcond_stage_model>>>>>>>>>>>>>>>>>>>>')
+ for name, params in self.structcond_stage_model.named_parameters():
+ if params.requires_grad:
+ param_list.append(name)
+ print(param_list)
+
+ # P2 weighting: https://github.com/jychoi118/P2-weighting
+ if p2_gamma is not None:
+ assert p2_k is not None
+ self.p2_gamma = p2_gamma
+ self.p2_k = p2_k
+ self.snr = 1.0 / (1 - self.alphas_cumprod) - 1
+ else:
+ self.snr = None
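+
+        # For reference: self.snr equals alphas_cumprod / (1 - alphas_cumprod), the per-step
+        # signal-to-noise ratio; P2 weighting typically rescales the loss by a factor of
+        # 1 / (p2_k + SNR(t)) ** p2_gamma, as in the repository linked above.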
+
+ # Support time respacing during training
+ if self.time_replace is None:
+ self.time_replace = kwargs['timesteps']
+ use_timesteps = set(space_timesteps(kwargs['timesteps'], [self.time_replace]))
+ last_alpha_cumprod = 1.0
+ new_betas = []
+ timestep_map = []
+ for i, alpha_cumprod in enumerate(self.alphas_cumprod):
+ if i in use_timesteps:
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+ last_alpha_cumprod = alpha_cumprod
+ timestep_map.append(i)
+ new_betas = [beta.data.cpu().numpy() for beta in new_betas]
+ self.register_schedule(given_betas=np.array(new_betas), timesteps=len(new_betas), linear_start=kwargs['linear_start'], linear_end=kwargs['linear_end'])
+ self.ori_timesteps = list(use_timesteps)
+ self.ori_timesteps.sort()
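+
+        # Illustrative example: with timesteps=1000 and time_replace=200, use_timesteps keeps 200
+        # of the original 1000 indices, new_betas is recomputed so that alphas_cumprod is preserved
+        # at the kept indices, and ori_timesteps maps the shortened schedule back to the original
+        # timestep indices.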
+
+ def make_cond_schedule(self, ):
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
+ self.cond_ids[:self.num_timesteps_cond] = ids
+
+ @rank_zero_only
+ @torch.no_grad()
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+ # only for very first batch
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
+ # set rescale weight to 1./std of encodings
+ print("### USING STD-RESCALING ###")
+ x = super().get_input(batch, self.first_stage_key)
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+ del self.scale_factor
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
+ print(f"setting self.scale_factor to {self.scale_factor}")
+ print("### USING STD-RESCALING ###")
+
+ def register_schedule(self,
+ given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
+
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
+ if self.shorten_cond_schedule:
+ self.make_cond_schedule()
+
+ def instantiate_first_stage(self, config):
+ model = instantiate_from_config(config)
+ self.first_stage_model = model.eval()
+ self.first_stage_model.train = disabled_train
+ for param in self.first_stage_model.parameters():
+ param.requires_grad = False
+
+ def instantiate_cond_stage(self, config):
+ if not self.cond_stage_trainable:
+ if config == "__is_first_stage__":
+ print("Using first stage also as cond stage.")
+ self.cond_stage_model = self.first_stage_model
+ elif config == "__is_unconditional__":
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
+ self.cond_stage_model = None
+ # self.be_unconditional = True
+ else:
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model.eval()
+ # self.cond_stage_model.train = disabled_train
+ for name, param in self.cond_stage_model.named_parameters():
+ if 'final_projector' not in name:
+ param.requires_grad = False
+ else:
+ assert config != '__is_first_stage__'
+ assert config != '__is_unconditional__'
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model
+ self.cond_stage_model.train()
+
+ def instantiate_structcond_stage(self, config):
+ model = instantiate_from_config(config)
+ self.structcond_stage_model = model
+ self.structcond_stage_model.train()
+
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
+ denoise_row = []
+ for zd in tqdm(samples, desc=desc):
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
+ force_not_quantize=force_no_decoder_quantization))
+ n_imgs_per_row = len(denoise_row)
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ def get_first_stage_encoding(self, encoder_posterior):
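+        # Draw a latent from the encoder posterior (or pass a plain tensor through) and apply the
+        # scale_factor so latents have roughly unit standard deviation (see on_train_batch_start).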
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+ z = encoder_posterior.sample()
+ elif isinstance(encoder_posterior, torch.Tensor):
+ z = encoder_posterior
+ else:
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
+ return self.scale_factor * z
+
+ def get_learned_conditioning(self, c):
+ if self.cond_stage_forward is None:
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
+ c = self.cond_stage_model.encode(c)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ else:
+ c = self.cond_stage_model(c)
+ else:
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+ return c
+
+ def meshgrid(self, h, w):
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
+
+ arr = torch.cat([y, x], dim=-1)
+ return arr
+
+ def delta_border(self, h, w):
+ """
+ :param h: height
+ :param w: width
+ :return: normalized distance to image border,
+        with min distance = 0 at the border and max distance = 0.5 at the image center
+ """
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
+ arr = self.meshgrid(h, w) / lower_right_corner
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
+ return edge_dist
+
+ def get_weighting(self, h, w, Ly, Lx, device):
+ weighting = self.delta_border(h, w)
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
+ self.split_input_params["clip_max_weight"], )
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
+
+ if self.split_input_params["tie_braker"]:
+ L_weighting = self.delta_border(Ly, Lx)
+ L_weighting = torch.clip(L_weighting,
+ self.split_input_params["clip_min_tie_weight"],
+ self.split_input_params["clip_max_tie_weight"])
+
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
+ weighting = weighting * L_weighting
+ return weighting
+
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
+ """
+ :param x: img of size (bs, c, h, w)
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
+ """
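+        # Unfold extracts Ly*Lx overlapping crops; fold stitches the processed crops back together.
+        # 'weighting' tapers each crop towards its border and 'normalization' (the folded weights)
+        # rescales overlapping regions so contributions from neighbouring crops blend smoothly.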
+ bs, nc, h, w = x.shape
+
+ # number of crops in image
+ Ly = (h - kernel_size[0]) // stride[0] + 1
+ Lx = (w - kernel_size[1]) // stride[1] + 1
+
+ if uf == 1 and df == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
+
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
+
+ elif uf > 1 and df == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
+ dilation=1, padding=0,
+ stride=(stride[0] * uf, stride[1] * uf))
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
+
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
+
+ elif df > 1 and uf == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
+ dilation=1, padding=0,
+ stride=(stride[0] // df, stride[1] // df))
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
+
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
+
+ else:
+ raise NotImplementedError
+
+ return fold, unfold, normalization, weighting
+
+ @torch.no_grad()
+ def _dequeue_and_enqueue(self):
+        """Training pair pool for increasing degradation diversity in a batch, adapted from Real-ESRGAN:
+        https://github.com/xinntao/Real-ESRGAN
+
+        Batch processing limits the diversity of synthetic degradations within a batch. For example, samples in a
+        batch cannot have different resize scaling factors. Therefore, we employ this training pair pool
+        to increase the degradation diversity within a batch.
+        """
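+        # The pool stores queue_size LQ/GT pairs. While it is filling we only enqueue; once full, it is
+        # shuffled and the first b entries are swapped with the current batch, so each batch mixes
+        # degradations produced at different iterations.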
+ # initialize
+ b, c, h, w = self.lq.size()
+ if b == self.configs.data.params.batch_size:
+ if not hasattr(self, 'queue_size'):
+ self.queue_size = self.configs.data.params.train.params.get('queue_size', b*50)
+ if not hasattr(self, 'queue_lr'):
+ assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
+ self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
+ _, c, h, w = self.gt.size()
+ self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
+ self.queue_ptr = 0
+ if self.queue_ptr == self.queue_size: # the pool is full
+ # do dequeue and enqueue
+ # shuffle
+ idx = torch.randperm(self.queue_size)
+ self.queue_lr = self.queue_lr[idx]
+ self.queue_gt = self.queue_gt[idx]
+ # get first b samples
+ lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
+ gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
+ # update the queue
+ self.queue_lr[0:b, :, :, :] = self.lq.clone()
+ self.queue_gt[0:b, :, :, :] = self.gt.clone()
+
+ self.lq = lq_dequeue
+ self.gt = gt_dequeue
+ else:
+ # only do enqueue
+ self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
+ self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
+ self.queue_ptr = self.queue_ptr + b
+
+ def randn_cropinput(self, lq, gt, base_size=[64, 128, 256, 512]):
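+        # Crop a centred patch whose height and width are drawn independently from base_size; LQ and GT
+        # share the same pixel indices since the LQ has already been resized to the GT resolution.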
+ cur_size_h = random.choice(base_size)
+ cur_size_w = random.choice(base_size)
+ init_h = lq.size(-2)//2
+ init_w = lq.size(-1)//2
+ lq = lq[:, :, init_h-cur_size_h//2:init_h+cur_size_h//2, init_w-cur_size_w//2:init_w+cur_size_w//2]
+ gt = gt[:, :, init_h-cur_size_h//2:init_h+cur_size_h//2, init_w-cur_size_w//2:init_w+cur_size_w//2]
+ assert lq.size(-1)>=64
+ assert lq.size(-2)>=64
+ return [lq, gt]
+
+ @torch.no_grad()
+ def get_input(self, batch, k=None, return_first_stage_outputs=False, force_c_encode=False,
+ cond_key=None, return_original_cond=False, bs=None, val=False, text_cond=[''], return_gt=False, resize_lq=True):
+
+ """Degradation pipeline, modified from Real-ESRGAN:
+ https://github.com/xinntao/Real-ESRGAN
+ """
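+        # Two-stage synthetic degradation: (blur -> random resize -> noise -> JPEG) applied twice,
+        # followed by a final [resize back + sinc filter] / JPEG stage in random order, a paired
+        # random crop, and (optionally) bicubic upsampling of the LQ image back to the GT resolution.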
+
+        # cache the JPEG simulator and USM sharpener on the module so they are only constructed once
+        if not hasattr(self, 'jpeger'):
+            self.jpeger = DiffJPEG(differentiable=False).cuda()  # simulate JPEG compression artifacts
+        if not hasattr(self, 'usm_sharpener'):
+            self.usm_sharpener = USMSharp().cuda()  # do usm sharpening
+        jpeger = self.jpeger
+        usm_sharpener = self.usm_sharpener
+
+ im_gt = batch['gt'].cuda()
+ if self.use_usm:
+ im_gt = usm_sharpener(im_gt)
+ im_gt = im_gt.to(memory_format=torch.contiguous_format).float()
+ kernel1 = batch['kernel1'].cuda()
+ kernel2 = batch['kernel2'].cuda()
+ sinc_kernel = batch['sinc_kernel'].cuda()
+
+ ori_h, ori_w = im_gt.size()[2:4]
+
+ # ----------------------- The first degradation process ----------------------- #
+ # blur
+ out = filter2D(im_gt, kernel1)
+ # random resize
+ updown_type = random.choices(
+ ['up', 'down', 'keep'],
+ self.configs.degradation['resize_prob'],
+ )[0]
+ if updown_type == 'up':
+ scale = random.uniform(1, self.configs.degradation['resize_range'][1])
+ elif updown_type == 'down':
+ scale = random.uniform(self.configs.degradation['resize_range'][0], 1)
+ else:
+ scale = 1
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(out, scale_factor=scale, mode=mode)
+ # add noise
+ gray_noise_prob = self.configs.degradation['gray_noise_prob']
+ if random.random() < self.configs.degradation['gaussian_noise_prob']:
+ out = random_add_gaussian_noise_pt(
+ out,
+ sigma_range=self.configs.degradation['noise_range'],
+ clip=True,
+ rounds=False,
+ gray_prob=gray_noise_prob,
+ )
+ else:
+ out = random_add_poisson_noise_pt(
+ out,
+ scale_range=self.configs.degradation['poisson_scale_range'],
+ gray_prob=gray_noise_prob,
+ clip=True,
+ rounds=False)
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.configs.degradation['jpeg_range'])
+ out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
+ out = jpeger(out, quality=jpeg_p)
+
+ # ----------------------- The second degradation process ----------------------- #
+ # blur
+ if random.random() < self.configs.degradation['second_blur_prob']:
+ out = filter2D(out, kernel2)
+ # random resize
+ updown_type = random.choices(
+ ['up', 'down', 'keep'],
+ self.configs.degradation['resize_prob2'],
+ )[0]
+ if updown_type == 'up':
+ scale = random.uniform(1, self.configs.degradation['resize_range2'][1])
+ elif updown_type == 'down':
+ scale = random.uniform(self.configs.degradation['resize_range2'][0], 1)
+ else:
+ scale = 1
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(
+ out,
+ size=(int(ori_h / self.configs.sf * scale),
+ int(ori_w / self.configs.sf * scale)),
+ mode=mode,
+ )
+ # add noise
+ gray_noise_prob = self.configs.degradation['gray_noise_prob2']
+ if random.random() < self.configs.degradation['gaussian_noise_prob2']:
+ out = random_add_gaussian_noise_pt(
+ out,
+ sigma_range=self.configs.degradation['noise_range2'],
+ clip=True,
+ rounds=False,
+ gray_prob=gray_noise_prob,
+ )
+ else:
+ out = random_add_poisson_noise_pt(
+ out,
+ scale_range=self.configs.degradation['poisson_scale_range2'],
+ gray_prob=gray_noise_prob,
+ clip=True,
+ rounds=False,
+ )
+
+ # JPEG compression + the final sinc filter
+ # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
+ # as one operation.
+ # We consider two orders:
+ # 1. [resize back + sinc filter] + JPEG compression
+ # 2. JPEG compression + [resize back + sinc filter]
+ # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
+ if random.random() < 0.5:
+ # resize back + the final sinc filter
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(
+ out,
+ size=(ori_h // self.configs.sf,
+ ori_w // self.configs.sf),
+ mode=mode,
+ )
+ out = filter2D(out, sinc_kernel)
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.configs.degradation['jpeg_range2'])
+ out = torch.clamp(out, 0, 1)
+ out = jpeger(out, quality=jpeg_p)
+ else:
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.configs.degradation['jpeg_range2'])
+ out = torch.clamp(out, 0, 1)
+ out = jpeger(out, quality=jpeg_p)
+ # resize back + the final sinc filter
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = F.interpolate(
+ out,
+ size=(ori_h // self.configs.sf,
+ ori_w // self.configs.sf),
+ mode=mode,
+ )
+ out = filter2D(out, sinc_kernel)
+
+ # clamp and round
+ im_lq = torch.clamp(out, 0, 1.0)
+
+ # random crop
+ gt_size = self.configs.degradation['gt_size']
+ im_gt, im_lq = paired_random_crop(im_gt, im_lq, gt_size, self.configs.sf)
+ self.lq, self.gt = im_lq, im_gt
+
+ if resize_lq:
+ self.lq = F.interpolate(
+ self.lq,
+ size=(self.gt.size(-2),
+ self.gt.size(-1)),
+ mode='bicubic',
+ )
+
+ if random.random() < self.configs.degradation['no_degradation_prob'] or torch.isnan(self.lq).any():
+ self.lq = self.gt
+
+ # training pair pool
+ if not val and not self.random_size:
+ self._dequeue_and_enqueue()
+            # note: self.lq/self.gt may have been replaced by pooled samples in self._dequeue_and_enqueue
+ self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
+ self.lq = self.lq*2 - 1.0
+ self.gt = self.gt*2 - 1.0
+
+ if self.random_size:
+ self.lq, self.gt = self.randn_cropinput(self.lq, self.gt)
+
+ self.lq = torch.clamp(self.lq, -1.0, 1.0)
+
+ x = self.lq
+ y = self.gt
+ if bs is not None:
+ x = x[:bs]
+ y = y[:bs]
+ x = x.to(self.device)
+ y = y.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+
+ encoder_posterior_y = self.encode_first_stage(y)
+ z_gt = self.get_first_stage_encoding(encoder_posterior_y).detach()
+
+ xc = None
+ if self.use_positional_encodings:
+            raise NotImplementedError('positional encodings are not supported in this model')
+ pos_x, pos_y = self.compute_latent_shifts(batch)
+ c = {'pos_x': pos_x, 'pos_y': pos_y}
+
+ while len(text_cond) < z.size(0):
+ text_cond.append(text_cond[-1])
+ if len(text_cond) > z.size(0):
+ text_cond = text_cond[:z.size(0)]
+ assert len(text_cond) == z.size(0)
+
+ out = [z, text_cond]
+ out.append(z_gt)
+
+ if return_first_stage_outputs:
+ xrec = self.decode_first_stage(z_gt)
+ out.extend([x, self.gt, xrec])
+ if return_original_cond:
+ out.append(xc)
+
+ return out
+
+ @torch.no_grad()
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+ if predict_cids:
+ if z.dim() == 4:
+ z = torch.argmax(z.exp(), dim=1).long()
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+ z = 1. / self.scale_factor * z
+
+ if hasattr(self, "split_input_params"):
+ if self.split_input_params["patch_distributed_vq"]:
+ ks = self.split_input_params["ks"] # eg. (128, 128)
+ stride = self.split_input_params["stride"] # eg. (64, 64)
+ uf = self.split_input_params["vqf"]
+ bs, nc, h, w = z.shape
+ if ks[0] > h or ks[1] > w:
+ ks = (min(ks[0], h), min(ks[1], w))
+ print("reducing Kernel")
+
+ if stride[0] > h or stride[1] > w:
+ stride = (min(stride[0], h), min(stride[1], w))
+ print("reducing stride")
+
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
+
+ z = unfold(z) # (bn, nc * prod(**ks), L)
+ # 1. Reshape to img shape
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ # 2. apply model loop over last dim
+ if isinstance(self.first_stage_model, VQModelInterface):
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
+ force_not_quantize=predict_cids or force_not_quantize)
+ for i in range(z.shape[-1])]
+ else:
+
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
+ for i in range(z.shape[-1])]
+
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
+ o = o * weighting
+ # Reverse 1. reshape to img shape
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # stitch crops together
+ decoded = fold(o)
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
+ return decoded
+ else:
+ if isinstance(self.first_stage_model, VQModelInterface):
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ else:
+ return self.first_stage_model.decode(z)
+
+ else:
+ if isinstance(self.first_stage_model, VQModelInterface):
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ else:
+ return self.first_stage_model.decode(z)
+
+
+ # same as above but without decorator
+ def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+ if predict_cids:
+ if z.dim() == 4:
+ z = torch.argmax(z.exp(), dim=1).long()
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+ z = 1. / self.scale_factor * z
+
+ if hasattr(self, "split_input_params"):
+ if self.split_input_params["patch_distributed_vq"]:
+ ks = self.split_input_params["ks"] # eg. (128, 128)
+ stride = self.split_input_params["stride"] # eg. (64, 64)
+ uf = self.split_input_params["vqf"]
+ bs, nc, h, w = z.shape
+ if ks[0] > h or ks[1] > w:
+ ks = (min(ks[0], h), min(ks[1], w))
+ print("reducing Kernel")
+
+ if stride[0] > h or stride[1] > w:
+ stride = (min(stride[0], h), min(stride[1], w))
+ print("reducing stride")
+
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
+
+ z = unfold(z) # (bn, nc * prod(**ks), L)
+ # 1. Reshape to img shape
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ # 2. apply model loop over last dim
+ if isinstance(self.first_stage_model, VQModelInterface):
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
+ force_not_quantize=predict_cids or force_not_quantize)
+ for i in range(z.shape[-1])]
+ else:
+
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
+ for i in range(z.shape[-1])]
+
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
+ o = o * weighting
+ # Reverse 1. reshape to img shape
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # stitch crops together
+ decoded = fold(o)
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
+ return decoded
+ else:
+ if isinstance(self.first_stage_model, VQModelInterface):
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ else:
+ return self.first_stage_model.decode(z)
+
+ else:
+ if isinstance(self.first_stage_model, VQModelInterface):
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ else:
+ return self.first_stage_model.decode(z)
+
+ @torch.no_grad()
+ def encode_first_stage(self, x):
+ if hasattr(self, "split_input_params"):
+ if self.split_input_params["patch_distributed_vq"]:
+ ks = self.split_input_params["ks"] # eg. (128, 128)
+ stride = self.split_input_params["stride"] # eg. (64, 64)
+ df = self.split_input_params["vqf"]
+ self.split_input_params['original_image_size'] = x.shape[-2:]
+ bs, nc, h, w = x.shape
+ if ks[0] > h or ks[1] > w:
+ ks = (min(ks[0], h), min(ks[1], w))
+ print("reducing Kernel")
+
+ if stride[0] > h or stride[1] > w:
+ stride = (min(stride[0], h), min(stride[1], w))
+ print("reducing stride")
+
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
+ z = unfold(x) # (bn, nc * prod(**ks), L)
+ # Reshape to img shape
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
+ for i in range(z.shape[-1])]
+
+ o = torch.stack(output_list, axis=-1)
+ o = o * weighting
+
+ # Reverse reshape to img shape
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # stitch crops together
+ decoded = fold(o)
+ decoded = decoded / normalization
+ return decoded
+
+ else:
+ return self.first_stage_model.encode(x)
+ else:
+ return self.first_stage_model.encode(x)
+
+ def shared_step(self, batch, **kwargs):
+ x, c, gt = self.get_input(batch, self.first_stage_key)
+ loss = self(x, c, gt)
+ return loss
+
+ def forward(self, x, c, gt, *args, **kwargs):
+ index = np.random.randint(0, self.num_timesteps, size=x.size(0))
+ t = torch.from_numpy(index)
+ t = t.to(self.device).long()
+
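+        # t indexes the (possibly respaced) training schedule; t_ori below maps each step back to the
+        # corresponding timestep of the full original schedule fed to the conditioning and denoising networks.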
+ t_ori = torch.tensor([self.ori_timesteps[index_i] for index_i in index])
+ t_ori = t_ori.long().to(x.device)
+
+ if self.model.conditioning_key is not None:
+ assert c is not None
+ if self.cond_stage_trainable:
+ c = self.get_learned_conditioning(c)
+ else:
+ c = self.cond_stage_model(c)
+ if self.shorten_cond_schedule: # TODO: drop this option
+                raise NotImplementedError('shorten_cond_schedule is not supported in this training path')
+ tc = self.cond_ids[t].to(self.device)
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
+ if self.test_gt:
+ struc_c = self.structcond_stage_model(gt, t_ori)
+ else:
+ struc_c = self.structcond_stage_model(x, t_ori)
+ return self.p_losses(gt, c, struc_c, t, t_ori, x, *args, **kwargs)
+
+ def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
+ def rescale_bbox(bbox):
+ x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
+ y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
+ w = min(bbox[2] / crop_coordinates[2], 1 - x0)
+ h = min(bbox[3] / crop_coordinates[3], 1 - y0)
+ return x0, y0, w, h
+
+ return [rescale_bbox(b) for b in bboxes]
+
+ def apply_model(self, x_noisy, t, cond, struct_cond, return_ids=False):
+
+ if isinstance(cond, dict):
+            # hybrid case, cond is expected to be a dict
+ pass
+ else:
+ if not isinstance(cond, list):
+ cond = [cond]
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
+ cond = {key: cond}
+
+ if hasattr(self, "split_input_params"):
+ assert len(cond) == 1 # todo can only deal with one conditioning atm
+ assert not return_ids
+ ks = self.split_input_params["ks"] # eg. (128, 128)
+ stride = self.split_input_params["stride"] # eg. (64, 64)
+
+ h, w = x_noisy.shape[-2:]
+
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
+
+ z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
+ # Reshape to img shape
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+ z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
+
+ if self.cond_stage_key in ["image", "LR_image", "segmentation",
+ 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
+ c_key = next(iter(cond.keys())) # get key
+ c = next(iter(cond.values())) # get value
+ assert (len(c) == 1) # todo extend to list with more than one elem
+ c = c[0] # get element
+
+ c = unfold(c)
+ c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
+
+ elif self.cond_stage_key == 'coordinates_bbox':
+                assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'
+
+ # assuming padding of unfold is always 0 and its dilation is always 1
+ n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
+ full_img_h, full_img_w = self.split_input_params['original_image_size']
+ # as we are operating on latents, we need the factor from the original image size to the
+ # spatial latent size to properly rescale the crops for regenerating the bbox annotations
+ num_downs = self.first_stage_model.encoder.num_resolutions - 1
+ rescale_latent = 2 ** (num_downs)
+
+                # get top-left positions of patches conforming to the bbox tokenizer; therefore we
+                # need to rescale the top-left patch coordinates to lie in (0, 1)
+ tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
+ rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
+ for patch_nr in range(z.shape[-1])]
+
+ # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
+ patch_limits = [(x_tl, y_tl,
+ rescale_latent * ks[0] / full_img_w,
+ rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
+ # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
+
+ # tokenize crop coordinates for the bounding boxes of the respective patches
+ patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
+ for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
+ print(patch_limits_tknzd[0].shape)
+ # cut tknzd crop position from conditioning
+ assert isinstance(cond, dict), 'cond must be dict to be fed into model'
+ cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
+ print(cut_cond.shape)
+
+ adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
+ adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
+ print(adapted_cond.shape)
+ adapted_cond = self.get_learned_conditioning(adapted_cond)
+ print(adapted_cond.shape)
+ adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
+ print(adapted_cond.shape)
+
+ cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
+
+ else:
+ cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
+
+ # apply model by loop over crops
+ output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
+            assert not isinstance(output_list[0], tuple)  # todo: can't deal with multiple model outputs; check this never happens
+
+ o = torch.stack(output_list, axis=-1)
+ o = o * weighting
+ # Reverse reshape to img shape
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # stitch crops together
+ x_recon = fold(o) / normalization
+
+ else:
+ cond['struct_cond'] = struct_cond
+ x_recon = self.model(x_noisy, t, **cond)
+
+ if isinstance(x_recon, tuple) and not return_ids:
+ return x_recon[0]
+ else:
+ return x_recon
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+ This term can't be optimized, as it only depends on the encoder.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
+ return mean_flat(kl_prior) / np.log(2.0)
+
+ def p_losses(self, x_start, cond, struct_cond, t, t_ori, z_gt, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+
+        if self.mix_ratio > 0:
+            if random.random() < self.mix_ratio:
+                # mix a freshly drawn noise sample with the original one before re-noising
+                noise_new = torch.randn_like(x_start)
+                noise = noise_new * 0.5 + noise * 0.5
+                x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+
+ model_output = self.apply_model(x_noisy, t_ori, cond, struct_cond)
+
+ loss_dict = {}
+ prefix = 'train' if self.training else 'val'
+
+ if self.parameterization == "x0":
+ target = x_start
+ elif self.parameterization == "eps":
+ target = noise
+ elif self.parameterization == "v":
+ target = self.get_v(x_start, noise, t)
+ else:
+ raise NotImplementedError()
+
+ model_output_ = model_output
+
+ loss_simple = self.get_loss(model_output_, target, mean=False).mean([1, 2, 3])
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
+
+ #P2 weighting
+ if self.snr is not None:
+ self.snr = self.snr.to(loss_simple.device)
+ weight = extract_into_tensor(1 / (self.p2_k + self.snr)**self.p2_gamma, t, target.shape)
+ loss_simple = weight * loss_simple
+
+ logvar_t = self.logvar[t.cpu()].to(self.device)
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
+ if self.learn_logvar:
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
+ loss_dict.update({'logvar': self.logvar.data.mean()})
+
+ loss = self.l_simple_weight * loss.mean()
+
+ loss_vlb = self.get_loss(model_output_, target, mean=False).mean(dim=(1, 2, 3))
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
+ loss += (self.original_elbo_weight * loss_vlb)
+ loss_dict.update({f'{prefix}/loss': loss})
+
+ return loss, loss_dict
+
+ def p_mean_variance(self, x, c, struct_cond, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
+ return_x0=False, score_corrector=None, corrector_kwargs=None, t_replace=None):
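+        # t drives the (respaced) diffusion math below, while t_replace, if given, carries the
+        # original-schedule timesteps expected by the pretrained denoising network.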
+ if t_replace is None:
+ t_in = t
+ else:
+ t_in = t_replace
+ model_out = self.apply_model(x, t_in, c, struct_cond, return_ids=return_codebook_ids)
+
+ if score_corrector is not None:
+ assert self.parameterization == "eps"
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
+
+ if return_codebook_ids:
+ model_out, logits = model_out
+
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ elif self.parameterization == "v":
+ x_recon = self.predict_start_from_z_and_v(x, model_out, t)
+ else:
+ raise NotImplementedError()
+
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+ if quantize_denoised:
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ if return_codebook_ids:
+ return model_mean, posterior_variance, posterior_log_variance, logits
+ elif return_x0:
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
+ else:
+ return model_mean, posterior_variance, posterior_log_variance
+
+ def p_mean_variance_canvas(self, x, c, struct_cond, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
+ return_x0=False, score_corrector=None, corrector_kwargs=None, t_replace=None, tile_size=64, tile_overlap=32, batch_size=4, tile_weights=None):
+ """
+ Aggregation Sampling strategy for arbitrary-size image super-resolution
+ """
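+        # The latent canvas is split into overlapping tile_size tiles (processed in mini-batches of
+        # `batch_size`); per-tile noise predictions are blended with the Gaussian tile_weights and
+        # normalised by the per-pixel number of contributors before the usual posterior computation.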
+ assert tile_weights is not None
+
+ if t_replace is None:
+ t_in = t
+ else:
+ t_in = t_replace
+
+ _, _, h, w = x.size()
+
+ grid_rows = 0
+ cur_x = 0
+ while cur_x < x.size(-1):
+ cur_x = max(grid_rows * tile_size-tile_overlap * grid_rows, 0)+tile_size
+ grid_rows += 1
+
+ grid_cols = 0
+ cur_y = 0
+ while cur_y < x.size(-2):
+ cur_y = max(grid_cols * tile_size-tile_overlap * grid_cols, 0)+tile_size
+ grid_cols += 1
+
+ input_list = []
+ cond_list = []
+ noise_preds = []
+ for row in range(grid_rows):
+ noise_preds_row = []
+ for col in range(grid_cols):
+ if col < grid_cols-1 or row < grid_rows-1:
+ # extract tile from input image
+ ofs_x = max(row * tile_size-tile_overlap * row, 0)
+ ofs_y = max(col * tile_size-tile_overlap * col, 0)
+ # input tile area on total image
+ if row == grid_rows-1:
+ ofs_x = w - tile_size
+ if col == grid_cols-1:
+ ofs_y = h - tile_size
+
+ input_start_x = ofs_x
+ input_end_x = ofs_x + tile_size
+ input_start_y = ofs_y
+ input_end_y = ofs_y + tile_size
+
+ # print('input_start_x', input_start_x)
+ # print('input_end_x', input_end_x)
+ # print('input_start_y', input_start_y)
+ # print('input_end_y', input_end_y)
+
+ # input tile dimensions
+ input_tile_width = input_end_x - input_start_x
+ input_tile_height = input_end_y - input_start_y
+ input_tile = x[:, :, input_start_y:input_end_y, input_start_x:input_end_x]
+ input_list.append(input_tile)
+ cond_tile = struct_cond[:, :, input_start_y:input_end_y, input_start_x:input_end_x]
+ cond_list.append(cond_tile)
+
+ if len(input_list) == batch_size or col == grid_cols-1:
+ input_list = torch.cat(input_list, dim=0)
+ cond_list = torch.cat(cond_list, dim=0)
+
+ struct_cond_input = self.structcond_stage_model(cond_list, t_in[:input_list.size(0)])
+ model_out = self.apply_model(input_list, t_in[:input_list.size(0)], c[:input_list.size(0)], struct_cond_input, return_ids=return_codebook_ids)
+
+ if score_corrector is not None:
+ assert self.parameterization == "eps"
+ model_out = score_corrector.modify_score(self, model_out, input_list, t[:input_list.size(0)], c[:input_list.size(0)], **corrector_kwargs)
+
+ if return_codebook_ids:
+ model_out, logits = model_out
+
+ for sample_i in range(model_out.size(0)):
+ noise_preds_row.append(model_out[sample_i].unsqueeze(0))
+ input_list = []
+ cond_list = []
+
+ noise_preds.append(noise_preds_row)
+
+ # Stitch noise predictions for all tiles
+ noise_pred = torch.zeros(x.shape, device=x.device)
+ contributors = torch.zeros(x.shape, device=x.device)
+ # Add each tile contribution to overall latents
+ for row in range(grid_rows):
+ for col in range(grid_cols):
+ if col < grid_cols-1 or row < grid_rows-1:
+ # extract tile from input image
+ ofs_x = max(row * tile_size-tile_overlap * row, 0)
+ ofs_y = max(col * tile_size-tile_overlap * col, 0)
+ # input tile area on total image
+ if row == grid_rows-1:
+ ofs_x = w - tile_size
+ if col == grid_cols-1:
+ ofs_y = h - tile_size
+
+ input_start_x = ofs_x
+ input_end_x = ofs_x + tile_size
+ input_start_y = ofs_y
+ input_end_y = ofs_y + tile_size
+ # print(noise_preds[row][col].size())
+ # print(tile_weights.size())
+ # print(noise_pred.size())
+ noise_pred[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += noise_preds[row][col] * tile_weights
+ contributors[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += tile_weights
+ # Average overlapping areas with more than 1 contributor
+ noise_pred /= contributors
+ # noise_pred /= torch.sqrt(contributors)
+ model_out = noise_pred
+
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t[:model_out.size(0)], noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ elif self.parameterization == "v":
+ x_recon = self.predict_start_from_z_and_v(x, model_out, t[:model_out.size(0)])
+ else:
+ raise NotImplementedError()
+
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+ if quantize_denoised:
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
+
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t[:x_recon.size(0)])
+ if return_codebook_ids:
+ return model_mean, posterior_variance, posterior_log_variance, logits
+ elif return_x0:
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
+ else:
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, c, struct_cond, t, clip_denoised=False, repeat_noise=False,
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, t_replace=None):
+ b, *_, device = *x.shape, x.device
+ outputs = self.p_mean_variance(x=x, c=c, struct_cond=struct_cond, t=t, clip_denoised=clip_denoised,
+ return_codebook_ids=return_codebook_ids,
+ quantize_denoised=quantize_denoised,
+ return_x0=return_x0,
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, t_replace=t_replace)
+ if return_codebook_ids:
+ raise DeprecationWarning("Support dropped.")
+ model_mean, _, model_log_variance, logits = outputs
+ elif return_x0:
+ model_mean, _, model_log_variance, x0 = outputs
+ else:
+ model_mean, _, model_log_variance = outputs
+
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+
+ if return_codebook_ids:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
+ if return_x0:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
+ else:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def p_sample_canvas(self, x, c, struct_cond, t, clip_denoised=False, repeat_noise=False,
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, t_replace=None,
+ tile_size=64, tile_overlap=32, batch_size=4, tile_weights=None):
+ b, *_, device = *x.shape, x.device
+ outputs = self.p_mean_variance_canvas(x=x, c=c, struct_cond=struct_cond, t=t, clip_denoised=clip_denoised,
+ return_codebook_ids=return_codebook_ids,
+ quantize_denoised=quantize_denoised,
+ return_x0=return_x0,
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, t_replace=t_replace,
+ tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size, tile_weights=tile_weights)
+ if return_codebook_ids:
+ raise DeprecationWarning("Support dropped.")
+ model_mean, _, model_log_variance, logits = outputs
+ elif return_x0:
+ model_mean, _, model_log_variance, x0 = outputs
+ else:
+ model_mean, _, model_log_variance = outputs
+
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t[:b] == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+
+ if return_codebook_ids:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
+ if return_x0:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
+ else:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def progressive_denoising(self, cond, struct_cond, shape, verbose=True, callback=None, quantize_denoised=False,
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
+ log_every_t=None):
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ timesteps = self.num_timesteps
+ if batch_size is not None:
+ b = batch_size if batch_size is not None else shape[0]
+ shape = [batch_size] + list(shape)
+ else:
+ b = batch_size = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=self.device)
+ else:
+ img = x_T
+ intermediates = []
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ else:
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
+ total=timesteps) if verbose else reversed(
+ range(0, timesteps))
+ if type(temperature) == float:
+ temperature = [temperature] * timesteps
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img, x0_partial = self.p_sample(img, cond, struct_cond, ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised, return_x0=True,
+ temperature=temperature[i], noise_dropout=noise_dropout,
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(x0_partial)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_loop(self, cond, struct_cond, shape, return_intermediates=False,
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, start_T=None,
+ log_every_t=None, time_replace=None, adain_fea=None, interfea_path=None):
+
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ device = self.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ intermediates = [img]
+ if timesteps is None:
+ timesteps = self.num_timesteps
+
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
+ range(0, timesteps))
+
+ if mask is not None:
+ assert x0 is not None
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
+
+ batch_list = []
+ for i in iterator:
+ if time_replace is None or time_replace == 1000:
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
+ t_replace=None
+ else:
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
+ t_replace = repeat(torch.tensor([self.ori_timesteps[i]]), '1 -> b', b=img.size(0))
+ t_replace = t_replace.long().to(device)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ if t_replace is not None:
+ if start_T is not None:
+ if self.ori_timesteps[i] > start_T:
+ continue
+ struct_cond_input = self.structcond_stage_model(struct_cond, t_replace)
+ else:
+ if start_T is not None:
+ if i > start_T:
+ continue
+ struct_cond_input = self.structcond_stage_model(struct_cond, ts)
+
+ if interfea_path is not None:
+ batch_list.append(struct_cond_input)
+
+ img = self.p_sample(img, cond, struct_cond_input, ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised, t_replace=t_replace)
+
+ if adain_fea is not None:
+ if i < 1:
+ img = adaptive_instance_normalization(img, adain_fea)
+ if mask is not None:
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(img)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+ if len(batch_list) > 0:
+ num_batch = batch_list[0]['64'].size(0)
+ for batch_i in range(num_batch):
+ batch64_list = []
+ batch32_list = []
+ for num_i in range(len(batch_list)):
+ batch64_list.append(cal_pca_components(batch_list[num_i]['64'][batch_i], 3))
+ batch32_list.append(cal_pca_components(batch_list[num_i]['32'][batch_i], 3))
+ batch64_list = np.array(batch64_list)
+ batch32_list = np.array(batch32_list)
+
+ batch64_list = batch64_list - np.min(batch64_list)
+ batch64_list = batch64_list / np.max(batch64_list)
+ batch32_list = batch32_list - np.min(batch32_list)
+ batch32_list = batch32_list / np.max(batch32_list)
+
+ total_num = batch64_list.shape[0]
+
+ for index in range(total_num):
+ os.makedirs(os.path.join(interfea_path, 'fea_'+str(batch_i)+'_64'), exist_ok=True)
+ cur_path = os.path.join(interfea_path, 'fea_'+str(batch_i)+'_64', 'step_'+str(total_num-index)+'.png')
+ visualize_fea(cur_path, batch64_list[index])
+ os.makedirs(os.path.join(interfea_path, 'fea_'+str(batch_i)+'_32'), exist_ok=True)
+ cur_path = os.path.join(interfea_path, 'fea_'+str(batch_i)+'_32', 'step_'+str(total_num-index)+'.png')
+ visualize_fea(cur_path, batch32_list[index])
+
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ def _gaussian_weights(self, tile_width, tile_height, nbatches):
+ """Generates a gaussian mask of weights for tile contributions"""
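+        # The weight at position x is proportional to exp(-(x - midpoint)^2 / (2 * var * width^2)),
+        # so tile centres dominate and contributions fade towards tile borders.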
+ from numpy import pi, exp, sqrt
+ import numpy as np
+
+ latent_width = tile_width
+ latent_height = tile_height
+
+ var = 0.01
+ midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1
+ x_probs = [exp(-(x-midpoint)*(x-midpoint)/(latent_width*latent_width)/(2*var)) / sqrt(2*pi*var) for x in range(latent_width)]
+ midpoint = latent_height / 2
+ y_probs = [exp(-(y-midpoint)*(y-midpoint)/(latent_height*latent_height)/(2*var)) / sqrt(2*pi*var) for y in range(latent_height)]
+
+ weights = np.outer(y_probs, x_probs)
+ return torch.tile(torch.tensor(weights, device=self.betas.device), (nbatches, self.configs.model.params.channels, 1, 1))
+
+ @torch.no_grad()
+ def p_sample_loop_canvas(self, cond, struct_cond, shape, return_intermediates=False,
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, start_T=None,
+ log_every_t=None, time_replace=None, adain_fea=None, interfea_path=None, tile_size=64, tile_overlap=32, batch_size=4):
+
+ assert tile_size is not None
+
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ device = self.betas.device
+ b = batch_size
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ intermediates = [img]
+ if timesteps is None:
+ timesteps = self.num_timesteps
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
+ range(0, timesteps))
+
+ if mask is not None:
+ assert x0 is not None
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
+
+ tile_weights = self._gaussian_weights(tile_size, tile_size, 1)
+
+ for i in iterator:
+ if time_replace is None or time_replace == 1000:
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
+ t_replace=None
+ else:
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
+ t_replace = repeat(torch.tensor([self.ori_timesteps[i]]), '1 -> b', b=batch_size)
+ t_replace = t_replace.long().to(device)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ if interfea_path is not None:
+ for batch_i in range(struct_cond_input['64'].size(0)):
+ os.makedirs(os.path.join(interfea_path, 'fea_'+str(batch_i)+'_64'), exist_ok=True)
+ cur_path = os.path.join(interfea_path, 'fea_'+str(batch_i)+'_64', 'step_'+str(i)+'.png')
+ visualize_fea(cur_path, struct_cond_input['64'][batch_i, 0])
+ os.makedirs(os.path.join(interfea_path, 'fea_'+str(batch_i)+'_32'), exist_ok=True)
+ cur_path = os.path.join(interfea_path, 'fea_'+str(batch_i)+'_32', 'step_'+str(i)+'.png')
+ visualize_fea(cur_path, struct_cond_input['32'][batch_i, 0])
+
+ img = self.p_sample_canvas(img, cond, struct_cond, ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised, t_replace=t_replace,
+ tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size, tile_weights=tile_weights)
+
+ if adain_fea is not None:
+ if i < 1:
+ img = adaptive_instance_normalization(img, adain_fea)
+ if mask is not None:
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(img)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, cond, struct_cond, batch_size=16, return_intermediates=False, x_T=None,
+ verbose=True, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, shape=None, time_replace=None, adain_fea=None, interfea_path=None, start_T=None, **kwargs):
+
+ if shape is None:
+ shape = (batch_size, self.channels, self.image_size//8, self.image_size//8)
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ else:
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+ return self.p_sample_loop(cond,
+ struct_cond,
+ shape,
+ return_intermediates=return_intermediates, x_T=x_T,
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
+ mask=mask, x0=x0, time_replace=time_replace, adain_fea=adain_fea, interfea_path=interfea_path, start_T=start_T)
+
+ @torch.no_grad()
+ def sample_canvas(self, cond, struct_cond, batch_size=16, return_intermediates=False, x_T=None,
+ verbose=True, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, shape=None, time_replace=None, adain_fea=None, interfea_path=None, tile_size=64, tile_overlap=32, batch_size_sample=4, log_every_t=None, **kwargs):
+
+ if shape is None:
+ shape = (batch_size, self.channels, self.image_size//8, self.image_size//8)
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {key: cond[key] if not isinstance(cond[key], list) else
+ list(map(lambda x: x, cond[key])) for key in cond}
+ else:
+ cond = [c for c in cond] if isinstance(cond, list) else cond
+ return self.p_sample_loop_canvas(cond,
+ struct_cond,
+ shape,
+ return_intermediates=return_intermediates, x_T=x_T,
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
+ mask=mask, x0=x0, time_replace=time_replace, adain_fea=adain_fea, interfea_path=interfea_path, tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size_sample, log_every_t=log_every_t)
+
+ @torch.no_grad()
+ def sample_log(self,cond,struct_cond,batch_size,ddim, ddim_steps,**kwargs):
+
+ if ddim:
+ raise NotImplementedError
+ ddim_sampler = DDIMSampler(self)
+ shape = (self.channels, self.image_size//8, self.image_size//8)
+ samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
+ shape,cond,verbose=False,**kwargs)
+
+ else:
+ samples, intermediates = self.sample(cond=cond, struct_cond=struct_cond, batch_size=batch_size,
+ return_intermediates=True,**kwargs)
+
+ return samples, intermediates
+
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+ quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,
+ plot_diffusion_rows=False, **kwargs):
+
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c_lq, z_gt, x, gt, yrec, xc = self.get_input(batch, self.first_stage_key,
+ return_first_stage_outputs=True,
+ force_c_encode=True,
+ return_original_cond=True,
+ bs=N, val=True)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ if self.test_gt:
+ log["gt"] = gt
+ else:
+ log["inputs"] = x
+ log["reconstruction"] = gt
+ log["recon_lq"] = self.decode_first_stage(z)
+
+ c = self.cond_stage_model(c_lq)
+ if self.test_gt:
+ struct_cond = z_gt
+ else:
+ struct_cond = z
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ noise = torch.randn_like(z)
+ ddim_sampler = DDIMSampler(self)
+ with self.ema_scope("Plotting"):
+ if self.time_replace is not None:
+ cur_time_step=self.time_replace
+ else:
+ cur_time_step = 1000
+
+ samples, z_denoise_row = self.sample(cond=c, struct_cond=struct_cond, batch_size=N, timesteps=cur_time_step, return_intermediates=True, time_replace=self.time_replace)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
+ self.first_stage_model, IdentityFirstStage):
+ with self.ema_scope("Plotting Quantized Denoised"):
+ samples, z_denoise_row = self.sample_log(cond=c,struct_cond=struct_cond,batch_size=N,ddim=use_ddim,
+ ddim_steps=ddim_steps,eta=ddim_eta,
+                                                         quantize_denoised=True)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_x0_quantized"] = x_samples
+
+ if inpaint:
+            raise NotImplementedError('inpainting is not supported in log_images')
+ # make a simple center square
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
+ mask = torch.ones(N, h, w).to(self.device)
+ # zeros will be filled in
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
+ mask = mask[:, None, ...]
+ with self.ema_scope("Plotting Inpaint"):
+
+ samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_inpainting"] = x_samples
+ log["mask"] = mask
+
+ # outpaint
+ with self.ema_scope("Plotting Outpaint"):
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_outpainting"] = x_samples
+
+ if plot_progressive_rows:
+ with self.ema_scope("Plotting Progressives"):
+ img, progressives = self.progressive_denoising(c, struct_cond=struct_cond,
+ shape=(self.channels, self.image_size//8, self.image_size//8),
+ batch_size=N)
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+ log["progressive_row"] = prog_row
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ params = params + list(self.cond_stage_model.parameters())
+ params = params + list(self.structcond_stage_model.parameters())
+        if self.learn_logvar:
+            raise NotImplementedError('optimizing logvar is disabled for this model')
+            print('Diffusion model optimizing logvar')
+            params.append(self.logvar)
+ opt = torch.optim.AdamW(params, lr=lr)
+ if self.use_scheduler:
+ assert 'target' in self.scheduler_config
+ scheduler = instantiate_from_config(self.scheduler_config)
+
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
+ 'interval': 'step',
+ 'frequency': 1
+ }]
+ return [opt], scheduler
+ return opt
+
+ @torch.no_grad()
+ def to_rgb(self, x):
+ x = x.float()
+ if not hasattr(self, "colorize"):
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = nn.functional.conv2d(x, weight=self.colorize)
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
+ return x
+
+class DiffusionWrapper(pl.LightningModule):
+ def __init__(self, diff_model_config, conditioning_key):
+ super().__init__()
+ self.diffusion_model = instantiate_from_config(diff_model_config)
+ self.conditioning_key = conditioning_key
+ assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
+
+ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, struct_cond=None, seg_cond=None):
+ if self.conditioning_key is None:
+ out = self.diffusion_model(x, t)
+ elif self.conditioning_key == 'concat':
+ xc = torch.cat([x] + c_concat, dim=1)
+ out = self.diffusion_model(xc, t)
+ elif self.conditioning_key == 'crossattn':
+ cc = torch.cat(c_crossattn, 1)
+ if seg_cond is None:
+ out = self.diffusion_model(x, t, context=cc, struct_cond=struct_cond)
+ else:
+ out = self.diffusion_model(x, t, context=cc, struct_cond=struct_cond, seg_cond=seg_cond)
+ elif self.conditioning_key == 'hybrid':
+ xc = torch.cat([x] + c_concat, dim=1)
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(xc, t, context=cc)
+ elif self.conditioning_key == 'adm':
+ cc = c_crossattn[0]
+ out = self.diffusion_model(x, t, y=cc)
+ else:
+ raise NotImplementedError()
+
+ return out
+
+class Layout2ImgDiffusion(LatentDiffusion):
+ # TODO: move all layout-specific hacks to this class
+ def __init__(self, cond_stage_key, *args, **kwargs):
+ assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
+ super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
+
+ def log_images(self, batch, N=8, *args, **kwargs):
+ logs = super().log_images(batch=batch, N=N, *args, **kwargs)
+
+ key = 'train' if self.training else 'validation'
+ dset = self.trainer.datamodule.datasets[key]
+ mapper = dset.conditional_builders[self.cond_stage_key]
+
+ bbox_imgs = []
+ map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
+ for tknzd_bbox in batch[self.cond_stage_key][:N]:
+ bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
+ bbox_imgs.append(bboximg)
+
+ cond_img = torch.stack(bbox_imgs, dim=0)
+ logs['bbox_image'] = cond_img
+ return logs
diff --git a/StableSR/ldm/models/diffusion/ddpm_inv.py b/StableSR/ldm/models/diffusion/ddpm_inv.py
new file mode 100644
index 0000000000000000000000000000000000000000..457057c26ce52f03c369e1de4c0f59effed9b0d6
--- /dev/null
+++ b/StableSR/ldm/models/diffusion/ddpm_inv.py
@@ -0,0 +1,1548 @@
+"""
+wild mixture of
+https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
+https://github.com/CompVis/taming-transformers
+-- merci
+"""
+
+import torch
+
+import torch.nn as nn
+import os
+import numpy as np
+import pytorch_lightning as pl
+from torch.optim.lr_scheduler import LambdaLR
+from einops import rearrange, repeat
+from contextlib import contextmanager
+from functools import partial
+from tqdm import tqdm
+from torchvision.utils import make_grid
+from pytorch_lightning.utilities.distributed import rank_zero_only
+
+from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
+from ldm.modules.ema import LitEma
+from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
+from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+from ldm.models.diffusion.ddim import DDIMSampler
+
+
+__conditioning_keys__ = {'concat': 'c_concat',
+ 'crossattn': 'c_crossattn',
+ 'adm': 'y'}
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def uniform_on_device(r1, r2, shape, device):
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
+
+
+class DDPM(pl.LightningModule):
+ # classic DDPM with Gaussian diffusion, in image space
+ def __init__(self,
+ unet_config,
+ timesteps=1000,
+ beta_schedule="linear",
+ loss_type="l2",
+ ckpt_path=None,
+ ignore_keys=[],
+ load_only_unet=False,
+ monitor="val/loss",
+ use_ema=True,
+ first_stage_key="image",
+ image_size=256,
+ channels=3,
+ log_every_t=100,
+ clip_denoised=True,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ given_betas=None,
+ original_elbo_weight=0.,
+ embedding_reg_weight=0.,
+ unfreeze_model=False,
+ model_lr=0.,
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+ l_simple_weight=1.,
+ conditioning_key=None,
+ parameterization="eps", # all assuming fixed variance schedules
+ scheduler_config=None,
+ use_positional_encodings=False,
+ learn_logvar=False,
+ logvar_init=0.,
+ ):
+ super().__init__()
+ assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
+ self.parameterization = parameterization
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
+ self.cond_stage_model = None
+ self.clip_denoised = clip_denoised
+ self.log_every_t = log_every_t
+ self.first_stage_key = first_stage_key
+ self.image_size = image_size # try conv?
+ self.channels = channels
+ self.use_positional_encodings = use_positional_encodings
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
+ count_params(self.model, verbose=True)
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self.model)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ self.use_scheduler = scheduler_config is not None
+ if self.use_scheduler:
+ self.scheduler_config = scheduler_config
+
+ self.v_posterior = v_posterior
+ self.original_elbo_weight = original_elbo_weight
+ self.l_simple_weight = l_simple_weight
+ self.embedding_reg_weight = embedding_reg_weight
+
+ self.unfreeze_model = unfreeze_model
+ self.model_lr = model_lr
+
+ if monitor is not None:
+ self.monitor = monitor
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
+
+ self.loss_type = loss_type
+
+ self.learn_logvar = learn_logvar
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
+ if self.learn_logvar:
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+
+
+ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ if exists(given_betas):
+ betas = given_betas
+ else:
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+ cosine_s=cosine_s)
+ alphas = 1. - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+ timesteps, = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer('betas', to_torch(betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
+ 1. - alphas_cumprod) + self.v_posterior * betas
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
+ self.register_buffer('posterior_mean_coef1', to_torch(
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
+ self.register_buffer('posterior_mean_coef2', to_torch(
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
+
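+        # per-timestep weights for the variational-bound term; for eps-prediction this is
+        # beta_t^2 / (2 * sigma_t^2 * alpha_t * (1 - alpha_bar_t))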
+ if self.parameterization == "eps":
+ lvlb_weights = self.betas ** 2 / (
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
+ elif self.parameterization == "x0":
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
+ else:
+ raise NotImplementedError("mu not supported")
+ # TODO how to choose this term
+ lvlb_weights[0] = lvlb_weights[1]
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
+        assert not torch.isnan(self.lvlb_weights).any()
+
+ @contextmanager
+ def ema_scope(self, context=None):
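+        # temporarily swap the EMA weights into self.model for the duration of the context,
+        # then restore the live training weights on exit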
+ if self.use_ema:
+ self.model_ema.store(self.model.parameters())
+ self.model_ema.copy_to(self.model)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.model.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+ sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+ return mean, variance, log_variance
+
+ def predict_start_from_noise(self, x_t, t, noise):
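+        # invert the forward process: x0_hat = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t),
+        # written with the precomputed sqrt(1/alpha_bar_t) and sqrt(1/alpha_bar_t - 1) buffers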
+ return (
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+ )
+
+ def q_posterior(self, x_start, x_t, t):
+ posterior_mean = (
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self, x, t, clip_denoised: bool):
+ model_out = self.model(x, t)
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
+ b, *_, device = *x.shape, x.device
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
+ noise = noise_like(x.shape, device, repeat_noise)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def p_sample_loop(self, shape, return_intermediates=False):
+ device = self.betas.device
+ b = shape[0]
+ img = torch.randn(shape, device=device)
+ intermediates = [img]
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
+ img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
+ clip_denoised=self.clip_denoised)
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
+ intermediates.append(img)
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, batch_size=16, return_intermediates=False):
+ image_size = self.image_size
+ channels = self.channels
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
+ return_intermediates=return_intermediates)
+
+ def q_sample(self, x_start, t, noise=None):
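+        # closed-form forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise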
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+ def get_loss(self, pred, target, mean=True):
+ if self.loss_type == 'l1':
+ loss = (target - pred).abs()
+ if mean:
+ loss = loss.mean()
+ elif self.loss_type == 'l2':
+ if mean:
+ loss = torch.nn.functional.mse_loss(target, pred)
+ else:
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
+ else:
+            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
+
+ return loss
+
+ def p_losses(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_out = self.model(x_noisy, t)
+
+ loss_dict = {}
+ if self.parameterization == "eps":
+ target = noise
+ elif self.parameterization == "x0":
+ target = x_start
+ else:
+            raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
+
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
+
+ log_prefix = 'train' if self.training else 'val'
+
+ loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
+ loss_simple = loss.mean() * self.l_simple_weight
+
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
+ loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
+
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
+
+ loss_dict.update({f'{log_prefix}/loss': loss})
+
+ return loss, loss_dict
+
+ def forward(self, x, *args, **kwargs):
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+ return self.p_losses(x, t, *args, **kwargs)
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = rearrange(x, 'b h w c -> b c h w')
+ x = x.to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def shared_step(self, batch):
+ x = self.get_input(batch, self.first_stage_key)
+ loss, loss_dict = self(x)
+ return loss, loss_dict
+
+ def training_step(self, batch, batch_idx):
+ loss, loss_dict = self.shared_step(batch)
+
+ self.log_dict(loss_dict, prog_bar=True,
+ logger=True, on_step=True, on_epoch=True)
+
+ self.log("global_step", self.global_step,
+ prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+ if self.use_scheduler:
+ lr = self.optimizers().param_groups[0]['lr']
+ self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+ return loss
+
+ @torch.no_grad()
+ def validation_step(self, batch, batch_idx):
+ _, loss_dict_no_ema = self.shared_step(batch)
+ with self.ema_scope():
+ _, loss_dict_ema = self.shared_step(batch)
+ loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self.model)
+
+ def _get_rows_from_list(self, samples):
+ n_imgs_per_row = len(samples)
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.first_stage_key)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ x = x.to(self.device)[:N]
+ log["inputs"] = x
+
+ # get diffusion row
+ diffusion_row = list()
+ x_start = x[:n_row]
+
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(x_start)
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ diffusion_row.append(x_noisy)
+
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
+
+ if sample:
+ # get denoise row
+ with self.ema_scope("Plotting"):
+ samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
+
+ log["samples"] = samples
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ if self.learn_logvar:
+ params = params + [self.logvar]
+ opt = torch.optim.AdamW(params, lr=lr)
+ return opt
+
+
+class LatentDiffusion(DDPM):
+ """main class"""
+ def __init__(self,
+ first_stage_config,
+ cond_stage_config,
+ personalization_config,
+ num_timesteps_cond=None,
+ cond_stage_key="image",
+ cond_stage_trainable=False,
+ concat_mode=True,
+ cond_stage_forward=None,
+ conditioning_key=None,
+ scale_factor=1.0,
+ scale_by_std=False,
+ *args, **kwargs):
+
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
+ self.scale_by_std = scale_by_std
+ assert self.num_timesteps_cond <= kwargs['timesteps']
+ # for backwards compatibility after implementation of DiffusionWrapper
+ if conditioning_key is None:
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
+ if cond_stage_config == '__is_unconditional__':
+ conditioning_key = None
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ignore_keys = kwargs.pop("ignore_keys", [])
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+ self.concat_mode = concat_mode
+ self.cond_stage_trainable = cond_stage_trainable
+ self.cond_stage_key = cond_stage_key
+
+ try:
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
+        except Exception:
+ self.num_downs = 0
+ if not scale_by_std:
+ self.scale_factor = scale_factor
+ else:
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
+ self.instantiate_first_stage(first_stage_config)
+ self.instantiate_cond_stage(cond_stage_config)
+
+ self.cond_stage_forward = cond_stage_forward
+ self.clip_denoised = False
+ self.bbox_tokenizer = None
+
+ self.restarted_from_ckpt = False
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys)
+ self.restarted_from_ckpt = True
+
+
+ if not self.unfreeze_model:
+ self.cond_stage_model.eval()
+ self.cond_stage_model.train = disabled_train
+ for param in self.cond_stage_model.parameters():
+ param.requires_grad = False
+
+ self.model.eval()
+ self.model.train = disabled_train
+ for param in self.model.parameters():
+ param.requires_grad = False
+
+ self.embedding_manager = self.instantiate_embedding_manager(personalization_config, self.cond_stage_model)
+
+ for param in self.embedding_manager.embedding_parameters():
+ param.requires_grad = True
+
+ def make_cond_schedule(self, ):
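+        # map every diffusion timestep to one of num_timesteps_cond conditioning timesteps;
+        # steps beyond the shortened range fall back to the final timestep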
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
+ self.cond_ids[:self.num_timesteps_cond] = ids
+
+ @rank_zero_only
+ @torch.no_grad()
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+ # only for very first batch
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
+ # set rescale weight to 1./std of encodings
+ print("### USING STD-RESCALING ###")
+ x = super().get_input(batch, self.first_stage_key)
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+ del self.scale_factor
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
+ print(f"setting self.scale_factor to {self.scale_factor}")
+ print("### USING STD-RESCALING ###")
+
+
+ def register_schedule(self,
+ given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
+
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
+ if self.shorten_cond_schedule:
+ self.make_cond_schedule()
+
+ def instantiate_first_stage(self, config):
+ model = instantiate_from_config(config)
+ self.first_stage_model = model.eval()
+ self.first_stage_model.train = disabled_train
+ for param in self.first_stage_model.parameters():
+ param.requires_grad = False
+
+ def instantiate_cond_stage(self, config):
+ if not self.cond_stage_trainable:
+ if config == "__is_first_stage__":
+ print("Using first stage also as cond stage.")
+ self.cond_stage_model = self.first_stage_model
+ elif config == "__is_unconditional__":
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
+ self.cond_stage_model = None
+ # self.be_unconditional = True
+ else:
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model.eval()
+ self.cond_stage_model.train = disabled_train
+ for param in self.cond_stage_model.parameters():
+ param.requires_grad = False
+ else:
+ assert config != '__is_first_stage__'
+ assert config != '__is_unconditional__'
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model
+
+
+ def instantiate_embedding_manager(self, config, embedder):
+ model = instantiate_from_config(config, embedder=embedder)
+
+ if config.params.get("embedding_manager_ckpt", None): # do not load if missing OR empty string
+ model.load(config.params.embedding_manager_ckpt)
+
+ return model
+
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
+ denoise_row = []
+ for zd in tqdm(samples, desc=desc):
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
+ force_not_quantize=force_no_decoder_quantization))
+ n_imgs_per_row = len(denoise_row)
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ def get_first_stage_encoding(self, encoder_posterior):
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+ z = encoder_posterior.sample()
+ elif isinstance(encoder_posterior, torch.Tensor):
+ z = encoder_posterior
+ else:
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
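+        # scale_factor rescales the latents to roughly unit standard deviation
+        # (set via std-rescaling on the first batch when scale_by_std is enabled)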
+ return self.scale_factor * z
+
+ def get_learned_conditioning(self, c):
+ if self.cond_stage_forward is None:
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
+ c = self.cond_stage_model.encode(c, embedding_manager=self.embedding_manager)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ else:
+ c = self.cond_stage_model(c)
+ else:
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+ return c
+
+ def meshgrid(self, h, w):
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
+
+ arr = torch.cat([y, x], dim=-1)
+ return arr
+
+ def delta_border(self, h, w):
+ """
+ :param h: height
+ :param w: width
+ :return: normalized distance to image border,
+        with min distance = 0 at border and max dist = 0.5 at image center
+ """
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
+ arr = self.meshgrid(h, w) / lower_right_corner
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
+ return edge_dist
+
+ def get_weighting(self, h, w, Ly, Lx, device):
+ weighting = self.delta_border(h, w)
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
+ self.split_input_params["clip_max_weight"], )
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
+
+ if self.split_input_params["tie_braker"]:
+ L_weighting = self.delta_border(Ly, Lx)
+ L_weighting = torch.clip(L_weighting,
+ self.split_input_params["clip_min_tie_weight"],
+ self.split_input_params["clip_max_tie_weight"])
+
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
+ weighting = weighting * L_weighting
+ return weighting
+
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
+ """
+ :param x: img of size (bs, c, h, w)
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
+ """
+ bs, nc, h, w = x.shape
+
+ # number of crops in image
+ Ly = (h - kernel_size[0]) // stride[0] + 1
+ Lx = (w - kernel_size[1]) // stride[1] + 1
+
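+        # Tiled processing: unfold extracts Ly*Lx overlapping crops, each crop is weighted so it
+        # tapers towards its border (see delta_border), and fold stitches the crops back together;
+        # dividing by `normalization` (= fold of the weighting) averages the overlapping regions.
+        # uf/df account for the decoder upsampling / encoder downsampling factor of the first stage.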
+ if uf == 1 and df == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
+
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
+
+ elif uf > 1 and df == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+            fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[1] * uf),
+ dilation=1, padding=0,
+ stride=(stride[0] * uf, stride[1] * uf))
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
+
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
+
+ elif df > 1 and uf == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+            fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[1] // df),
+ dilation=1, padding=0,
+ stride=(stride[0] // df, stride[1] // df))
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
+
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
+
+ else:
+ raise NotImplementedError
+
+ return fold, unfold, normalization, weighting
+
+ @torch.no_grad()
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
+ cond_key=None, return_original_cond=False, bs=None):
+ x = super().get_input(batch, k)
+ if bs is not None:
+ x = x[:bs]
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+
+ if self.model.conditioning_key is not None:
+ if cond_key is None:
+ cond_key = self.cond_stage_key
+ if cond_key != self.first_stage_key:
+ if cond_key in ['caption', 'coordinates_bbox']:
+ xc = batch[cond_key]
+ elif cond_key == 'class_label':
+ xc = batch
+ else:
+ xc = super().get_input(batch, cond_key).to(self.device)
+ else:
+ xc = x
+ if not self.cond_stage_trainable or force_c_encode:
+ if isinstance(xc, dict) or isinstance(xc, list):
+ # import pudb; pudb.set_trace()
+ c = self.get_learned_conditioning(xc)
+ else:
+ c = self.get_learned_conditioning(xc.to(self.device))
+ else:
+ c = xc
+ if bs is not None:
+ c = c[:bs]
+
+ if self.use_positional_encodings:
+ pos_x, pos_y = self.compute_latent_shifts(batch)
+ ckey = __conditioning_keys__[self.model.conditioning_key]
+ c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
+
+ else:
+ c = None
+ xc = None
+ if self.use_positional_encodings:
+ pos_x, pos_y = self.compute_latent_shifts(batch)
+ c = {'pos_x': pos_x, 'pos_y': pos_y}
+ out = [z, c]
+ if return_first_stage_outputs:
+ xrec = self.decode_first_stage(z)
+ out.extend([x, xrec])
+ if return_original_cond:
+ out.append(xc)
+ return out
+
+ @torch.no_grad()
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+ if predict_cids:
+ if z.dim() == 4:
+ z = torch.argmax(z.exp(), dim=1).long()
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+ z = 1. / self.scale_factor * z
+
+ if hasattr(self, "split_input_params"):
+ if self.split_input_params["patch_distributed_vq"]:
+ ks = self.split_input_params["ks"] # eg. (128, 128)
+ stride = self.split_input_params["stride"] # eg. (64, 64)
+ uf = self.split_input_params["vqf"]
+ bs, nc, h, w = z.shape
+ if ks[0] > h or ks[1] > w:
+ ks = (min(ks[0], h), min(ks[1], w))
+ print("reducing Kernel")
+
+ if stride[0] > h or stride[1] > w:
+ stride = (min(stride[0], h), min(stride[1], w))
+ print("reducing stride")
+
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
+
+ z = unfold(z) # (bn, nc * prod(**ks), L)
+ # 1. Reshape to img shape
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ # 2. apply model loop over last dim
+ if isinstance(self.first_stage_model, VQModelInterface):
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
+ force_not_quantize=predict_cids or force_not_quantize)
+ for i in range(z.shape[-1])]
+ else:
+
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
+ for i in range(z.shape[-1])]
+
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
+ o = o * weighting
+ # Reverse 1. reshape to img shape
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # stitch crops together
+ decoded = fold(o)
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
+ return decoded
+ else:
+ if isinstance(self.first_stage_model, VQModelInterface):
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ else:
+ return self.first_stage_model.decode(z)
+
+ else:
+ if isinstance(self.first_stage_model, VQModelInterface):
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ else:
+ return self.first_stage_model.decode(z)
+
+ # same as above but without decorator
+ def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+ if predict_cids:
+ if z.dim() == 4:
+ z = torch.argmax(z.exp(), dim=1).long()
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+ z = 1. / self.scale_factor * z
+
+ if hasattr(self, "split_input_params"):
+ if self.split_input_params["patch_distributed_vq"]:
+ ks = self.split_input_params["ks"] # eg. (128, 128)
+ stride = self.split_input_params["stride"] # eg. (64, 64)
+ uf = self.split_input_params["vqf"]
+ bs, nc, h, w = z.shape
+ if ks[0] > h or ks[1] > w:
+ ks = (min(ks[0], h), min(ks[1], w))
+ print("reducing Kernel")
+
+ if stride[0] > h or stride[1] > w:
+ stride = (min(stride[0], h), min(stride[1], w))
+ print("reducing stride")
+
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
+
+ z = unfold(z) # (bn, nc * prod(**ks), L)
+ # 1. Reshape to img shape
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ # 2. apply model loop over last dim
+ if isinstance(self.first_stage_model, VQModelInterface):
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
+ force_not_quantize=predict_cids or force_not_quantize)
+ for i in range(z.shape[-1])]
+ else:
+
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
+ for i in range(z.shape[-1])]
+
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
+ o = o * weighting
+ # Reverse 1. reshape to img shape
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # stitch crops together
+ decoded = fold(o)
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
+ return decoded
+ else:
+ if isinstance(self.first_stage_model, VQModelInterface):
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ else:
+ return self.first_stage_model.decode(z)
+
+ else:
+ if isinstance(self.first_stage_model, VQModelInterface):
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
+ else:
+ return self.first_stage_model.decode(z)
+
+ @torch.no_grad()
+ def encode_first_stage(self, x):
+ if hasattr(self, "split_input_params"):
+ if self.split_input_params["patch_distributed_vq"]:
+ ks = self.split_input_params["ks"] # eg. (128, 128)
+ stride = self.split_input_params["stride"] # eg. (64, 64)
+ df = self.split_input_params["vqf"]
+ self.split_input_params['original_image_size'] = x.shape[-2:]
+ bs, nc, h, w = x.shape
+ if ks[0] > h or ks[1] > w:
+ ks = (min(ks[0], h), min(ks[1], w))
+ print("reducing Kernel")
+
+ if stride[0] > h or stride[1] > w:
+ stride = (min(stride[0], h), min(stride[1], w))
+ print("reducing stride")
+
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
+ z = unfold(x) # (bn, nc * prod(**ks), L)
+ # Reshape to img shape
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
+ for i in range(z.shape[-1])]
+
+ o = torch.stack(output_list, axis=-1)
+ o = o * weighting
+
+ # Reverse reshape to img shape
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # stitch crops together
+ decoded = fold(o)
+ decoded = decoded / normalization
+ return decoded
+
+ else:
+ return self.first_stage_model.encode(x)
+ else:
+ return self.first_stage_model.encode(x)
+
+ def shared_step(self, batch, **kwargs):
+ x, c = self.get_input(batch, self.first_stage_key)
+ loss = self(x, c)
+ return loss
+
+ def forward(self, x, c, *args, **kwargs):
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+ if self.model.conditioning_key is not None:
+ assert c is not None
+ if self.cond_stage_trainable:
+ c = self.get_learned_conditioning(c)
+ if self.shorten_cond_schedule: # TODO: drop this option
+ tc = self.cond_ids[t].to(self.device)
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
+
+ return self.p_losses(x, c, t, *args, **kwargs)
+
+ def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
+ def rescale_bbox(bbox):
+ x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
+ y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
+ w = min(bbox[2] / crop_coordinates[2], 1 - x0)
+ h = min(bbox[3] / crop_coordinates[3], 1 - y0)
+ return x0, y0, w, h
+
+ return [rescale_bbox(b) for b in bboxes]
+
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
+
+ if isinstance(cond, dict):
+            # hybrid case, cond is expected to be a dict
+ pass
+ else:
+ if not isinstance(cond, list):
+ cond = [cond]
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
+ cond = {key: cond}
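+        # at this point cond is a dict of lists keyed by 'c_concat' or 'c_crossattn',
+        # matching the keyword arguments expected by DiffusionWrapper.forward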
+
+ if hasattr(self, "split_input_params"):
+ assert len(cond) == 1 # todo can only deal with one conditioning atm
+ assert not return_ids
+ ks = self.split_input_params["ks"] # eg. (128, 128)
+ stride = self.split_input_params["stride"] # eg. (64, 64)
+
+ h, w = x_noisy.shape[-2:]
+
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
+
+ z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
+ # Reshape to img shape
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+ z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
+
+ if self.cond_stage_key in ["image", "LR_image", "segmentation",
+ 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
+ c_key = next(iter(cond.keys())) # get key
+ c = next(iter(cond.values())) # get value
+ assert (len(c) == 1) # todo extend to list with more than one elem
+ c = c[0] # get element
+
+ c = unfold(c)
+ c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
+
+ cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
+
+ elif self.cond_stage_key == 'coordinates_bbox':
+                assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'
+
+ # assuming padding of unfold is always 0 and its dilation is always 1
+ n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
+ full_img_h, full_img_w = self.split_input_params['original_image_size']
+ # as we are operating on latents, we need the factor from the original image size to the
+ # spatial latent size to properly rescale the crops for regenerating the bbox annotations
+ num_downs = self.first_stage_model.encoder.num_resolutions - 1
+ rescale_latent = 2 ** (num_downs)
+
+                # get top-left positions of the patches in the form expected by the bbox tokenizer,
+                # i.e. rescale the top-left patch coordinates to lie within (0, 1)
+ tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
+ rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
+ for patch_nr in range(z.shape[-1])]
+
+ # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
+ patch_limits = [(x_tl, y_tl,
+ rescale_latent * ks[0] / full_img_w,
+ rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
+ # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
+
+ # tokenize crop coordinates for the bounding boxes of the respective patches
+ patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
+ for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
+ print(patch_limits_tknzd[0].shape)
+ # cut tknzd crop position from conditioning
+ assert isinstance(cond, dict), 'cond must be dict to be fed into model'
+ cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
+ print(cut_cond.shape)
+
+ adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
+ adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
+ print(adapted_cond.shape)
+ adapted_cond = self.get_learned_conditioning(adapted_cond)
+ print(adapted_cond.shape)
+ adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
+ print(adapted_cond.shape)
+
+ cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
+
+ else:
+ cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
+
+ # apply model by loop over crops
+ output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
+ assert not isinstance(output_list[0],
+ tuple) # todo cant deal with multiple model outputs check this never happens
+
+ o = torch.stack(output_list, axis=-1)
+ o = o * weighting
+ # Reverse reshape to img shape
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
+ # stitch crops together
+ x_recon = fold(o) / normalization
+
+ else:
+ x_recon = self.model(x_noisy, t, **cond)
+
+ if isinstance(x_recon, tuple) and not return_ids:
+ return x_recon[0]
+ else:
+ return x_recon
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+ This term can't be optimized, as it only depends on the encoder.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
+ return mean_flat(kl_prior) / np.log(2.0)
+
+ def p_losses(self, x_start, cond, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_output = self.apply_model(x_noisy, t, cond)
+
+ loss_dict = {}
+ prefix = 'train' if self.training else 'val'
+
+ if self.parameterization == "x0":
+ target = x_start
+ elif self.parameterization == "eps":
+ target = noise
+ else:
+ raise NotImplementedError()
+
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
+
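+        # weight the simple loss with a per-timestep (optionally learned) log-variance:
+        # loss = loss_simple / exp(logvar_t) + logvar_t; with logvar_init=0 and learn_logvar=False
+        # this reduces to the plain loss_simple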
+ logvar_t = self.logvar[t].to(self.device)
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
+ if self.learn_logvar:
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
+ loss_dict.update({'logvar': self.logvar.data.mean()})
+
+ loss = self.l_simple_weight * loss.mean()
+
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
+ loss += (self.original_elbo_weight * loss_vlb)
+ loss_dict.update({f'{prefix}/loss': loss})
+
+ if self.embedding_reg_weight > 0:
+ loss_embedding_reg = self.embedding_manager.embedding_to_coarse_loss().mean()
+
+ loss_dict.update({f'{prefix}/loss_emb_reg': loss_embedding_reg})
+
+ loss += (self.embedding_reg_weight * loss_embedding_reg)
+ loss_dict.update({f'{prefix}/loss': loss})
+
+ return loss, loss_dict
+
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
+ t_in = t
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
+
+ if score_corrector is not None:
+ assert self.parameterization == "eps"
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
+
+ if return_codebook_ids:
+ model_out, logits = model_out
+
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ else:
+ raise NotImplementedError()
+
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+ if quantize_denoised:
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ if return_codebook_ids:
+ return model_mean, posterior_variance, posterior_log_variance, logits
+ elif return_x0:
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
+ else:
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
+ b, *_, device = *x.shape, x.device
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
+ return_codebook_ids=return_codebook_ids,
+ quantize_denoised=quantize_denoised,
+ return_x0=return_x0,
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+ if return_codebook_ids:
+ raise DeprecationWarning("Support dropped.")
+ model_mean, _, model_log_variance, logits = outputs
+ elif return_x0:
+ model_mean, _, model_log_variance, x0 = outputs
+ else:
+ model_mean, _, model_log_variance = outputs
+
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+
+ if return_codebook_ids:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
+ if return_x0:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
+ else:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
+ log_every_t=None):
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ timesteps = self.num_timesteps
+ if batch_size is not None:
+            b = batch_size
+ shape = [batch_size] + list(shape)
+ else:
+ b = batch_size = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=self.device)
+ else:
+ img = x_T
+ intermediates = []
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ else:
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
+ total=timesteps) if verbose else reversed(
+ range(0, timesteps))
+ if type(temperature) == float:
+ temperature = [temperature] * timesteps
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img, x0_partial = self.p_sample(img, cond, ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised, return_x0=True,
+ temperature=temperature[i], noise_dropout=noise_dropout,
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(x0_partial)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, start_T=None,
+ log_every_t=None):
+
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ device = self.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ intermediates = [img]
+ if timesteps is None:
+ timesteps = self.num_timesteps
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
+ range(0, timesteps))
+
+ if mask is not None:
+ assert x0 is not None
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img = self.p_sample(img, cond, ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised)
+ if mask is not None:
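+                # inpainting: keep the known region (mask == 1) from x0 noised to the current
+                # timestep; only the masked-out region keeps the model's sample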
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(img)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
+ verbose=True, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, shape=None,**kwargs):
+ if shape is None:
+ shape = (batch_size, self.channels, self.image_size, self.image_size)
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ else:
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+ return self.p_sample_loop(cond,
+ shape,
+ return_intermediates=return_intermediates, x_T=x_T,
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
+ mask=mask, x0=x0)
+
+ @torch.no_grad()
+ def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
+
+ if ddim:
+ ddim_sampler = DDIMSampler(self)
+ shape = (self.channels, self.image_size, self.image_size)
+ samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
+ shape,cond,verbose=False,**kwargs)
+
+ else:
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
+ return_intermediates=True,**kwargs)
+
+ return samples, intermediates
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+ quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,
+ plot_diffusion_rows=False, **kwargs):
+
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
+ return_first_stage_outputs=True,
+ force_c_encode=True,
+ return_original_cond=True,
+ bs=N)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption"]:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
+ log["conditioning"] = xc
+ elif self.cond_stage_key == 'class_label':
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
+ log['conditioning'] = xc
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ with self.ema_scope("Plotting"):
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
+ ddim_steps=ddim_steps,eta=ddim_eta)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
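+            # classifier-free guidance: build an unconditional embedding from empty prompts and
+            # sample again with unconditional_guidance_scale=5.0 for a guided comparison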
+ uc = self.get_learned_conditioning(len(c) * [""])
+ sample_scaled, _ = self.sample_log(cond=c,
+ batch_size=N,
+ ddim=use_ddim,
+ ddim_steps=ddim_steps,
+ eta=ddim_eta,
+ unconditional_guidance_scale=5.0,
+ unconditional_conditioning=uc)
+ log["samples_scaled"] = self.decode_first_stage(sample_scaled)
+
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
+ self.first_stage_model, IdentityFirstStage):
+ # also display when quantizing x0 while sampling
+ with self.ema_scope("Plotting Quantized Denoised"):
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
+ ddim_steps=ddim_steps,eta=ddim_eta,
+ quantize_denoised=True)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
+ # quantize_denoised=True)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_x0_quantized"] = x_samples
+
+ if inpaint:
+ # make a simple center square
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
+ mask = torch.ones(N, h, w).to(self.device)
+ # zeros will be filled in
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
+ mask = mask[:, None, ...]
+ with self.ema_scope("Plotting Inpaint"):
+
+ samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_inpainting"] = x_samples
+ log["mask"] = mask
+
+ # outpaint
+ with self.ema_scope("Plotting Outpaint"):
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_outpainting"] = x_samples
+
+ if plot_progressive_rows:
+ with self.ema_scope("Plotting Progressives"):
+ img, progressives = self.progressive_denoising(c,
+ shape=(self.channels, self.image_size, self.image_size),
+ batch_size=N)
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+ log["progressive_row"] = prog_row
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+
+ if self.embedding_manager is not None: # If using textual inversion
+ embedding_params = list(self.embedding_manager.embedding_parameters())
+
+ if self.unfreeze_model: # Are we allowing the base model to train? If so, set two different parameter groups.
+ model_params = list(self.cond_stage_model.parameters()) + list(self.model.parameters())
+ opt = torch.optim.AdamW([{"params": embedding_params, "lr": lr}, {"params": model_params}], lr=self.model_lr)
+ else: # Otherwise, train only embedding
+ opt = torch.optim.AdamW(embedding_params, lr=lr)
+ else:
+ params = list(self.model.parameters())
+ if self.cond_stage_trainable:
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
+ params = params + list(self.cond_stage_model.parameters())
+ if self.learn_logvar:
+ print('Diffusion model optimizing logvar')
+ params.append(self.logvar)
+
+ opt = torch.optim.AdamW(params, lr=lr)
+
+ return opt
+
+ def configure_opt_embedding(self):
+
+ self.cond_stage_model.eval()
+ self.cond_stage_model.train = disabled_train
+ for param in self.cond_stage_model.parameters():
+ param.requires_grad = False
+
+ self.model.eval()
+ self.model.train = disabled_train
+ for param in self.model.parameters():
+ param.requires_grad = False
+
+ for param in self.embedding_manager.embedding_parameters():
+ param.requires_grad = True
+
+ lr = self.learning_rate
+ params = list(self.embedding_manager.embedding_parameters())
+ return torch.optim.AdamW(params, lr=lr)
+
+ def configure_opt_model(self):
+
+ for param in self.cond_stage_model.parameters():
+ param.requires_grad = True
+
+ for param in self.model.parameters():
+ param.requires_grad = True
+
+ for param in self.embedding_manager.embedding_parameters():
+ param.requires_grad = True
+
+ model_params = list(self.cond_stage_model.parameters()) + list(self.model.parameters())
+ embedding_params = list(self.embedding_manager.embedding_parameters())
+ return torch.optim.AdamW([{"params": embedding_params, "lr": self.learning_rate}, {"params": model_params}], lr=self.model_lr)
+
+ @torch.no_grad()
+ def to_rgb(self, x):
+ x = x.float()
+ if not hasattr(self, "colorize"):
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = nn.functional.conv2d(x, weight=self.colorize)
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
+ return x
+
+ @rank_zero_only
+ def on_save_checkpoint(self, checkpoint):
+
+        if not self.unfreeze_model: # If we are not tuning the base model, clear the checkpoint contents so that only the learned embeddings are saved.
+ checkpoint.clear()
+
+ if os.path.isdir(self.trainer.checkpoint_callback.dirpath):
+ self.embedding_manager.save(os.path.join(self.trainer.checkpoint_callback.dirpath, "embeddings.pt"))
+
+ self.embedding_manager.save(os.path.join(self.trainer.checkpoint_callback.dirpath, f"embeddings_gs-{self.global_step}.pt"))
+
+
+class DiffusionWrapper(pl.LightningModule):
+ def __init__(self, diff_model_config, conditioning_key):
+ super().__init__()
+ self.diffusion_model = instantiate_from_config(diff_model_config)
+ self.conditioning_key = conditioning_key
+ assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
+
+ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
+ if self.conditioning_key is None:
+ out = self.diffusion_model(x, t)
+ elif self.conditioning_key == 'concat':
+ xc = torch.cat([x] + c_concat, dim=1)
+ out = self.diffusion_model(xc, t)
+ elif self.conditioning_key == 'crossattn':
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(x, t, context=cc)
+ elif self.conditioning_key == 'hybrid':
+ xc = torch.cat([x] + c_concat, dim=1)
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(xc, t, context=cc)
+ elif self.conditioning_key == 'adm':
+ cc = c_crossattn[0]
+ out = self.diffusion_model(x, t, y=cc)
+ else:
+ raise NotImplementedError()
+
+ return out
+
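+# Minimal usage sketch for DiffusionWrapper (added for illustration; `unet_config` and all
+# shapes below are placeholders, not values taken from this repository):
+#   wrapper = DiffusionWrapper(unet_config, conditioning_key='crossattn')
+#   x   = torch.randn(2, 4, 64, 64)          # noisy latents
+#   t   = torch.randint(0, 1000, (2,))       # diffusion timesteps
+#   ctx = torch.randn(2, 77, 768)            # e.g. text-encoder hidden states
+#   eps = wrapper(x, t, c_crossattn=[ctx])   # model output, same shape as x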
+
+class Layout2ImgDiffusion(LatentDiffusion):
+ # TODO: move all layout-specific hacks to this class
+ def __init__(self, cond_stage_key, *args, **kwargs):
+ assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
+ super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
+
+ def log_images(self, batch, N=8, *args, **kwargs):
+ logs = super().log_images(batch=batch, N=N, *args, **kwargs)
+
+ key = 'train' if self.training else 'validation'
+ dset = self.trainer.datamodule.datasets[key]
+ mapper = dset.conditional_builders[self.cond_stage_key]
+
+ bbox_imgs = []
+ map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
+ for tknzd_bbox in batch[self.cond_stage_key][:N]:
+ bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
+ bbox_imgs.append(bboximg)
+
+ cond_img = torch.stack(bbox_imgs, dim=0)
+ logs['bbox_image'] = cond_img
+ return logs
diff --git a/StableSR/ldm/models/diffusion/plms.py b/StableSR/ldm/models/diffusion/plms.py
new file mode 100644
index 0000000000000000000000000000000000000000..78eeb1003aa45d27bdbfc6b4a1d7ccbff57cd2e3
--- /dev/null
+++ b/StableSR/ldm/models/diffusion/plms.py
@@ -0,0 +1,236 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+from functools import partial
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
+
+
+class PLMSSampler(object):
+ def __init__(self, model, schedule="linear", **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+ if ddim_eta != 0:
+ raise ValueError('ddim_eta must be 0 for PLMS')
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+ alphas_cumprod = self.model.alphas_cumprod
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer('betas', to_torch(self.model.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,verbose=verbose)
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+               # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for PLMS sampling is {size}')
+
+ samples, intermediates = self.plms_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ )
+ return samples, intermediates
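+
+    # Usage sketch (added; all values are illustrative):
+    #   sampler = PLMSSampler(model)
+    #   samples, _ = sampler.sample(S=50, batch_size=4, shape=(4, 64, 64), conditioning=c,
+    #                               unconditional_guidance_scale=7.5,
+    #                               unconditional_conditioning=uc)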
+
+ @torch.no_grad()
+ def plms_sampling(self, cond, shape,
+ x_T=None, ddim_use_original_steps=False,
+ callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, log_every_t=100,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None,):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
+ time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
+ old_eps = []
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+ ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1. - mask) * img
+
+ outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ old_eps=old_eps, t_next=ts_next)
+ img, pred_x0, e_t = outs
+ old_eps.append(e_t)
+ if len(old_eps) >= 4:
+ old_eps.pop(0)
+ if callback: callback(i)
+ if img_callback: img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['x_inter'].append(img)
+ intermediates['pred_x0'].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
+ b, *_, device = *x.shape, x.device
+
+ def get_model_output(x, t):
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ e_t = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ c_in = torch.cat([unconditional_conditioning, c])
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ return e_t
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+
+ def get_x_prev_and_pred_x0(e_t, index):
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ e_t = get_model_output(x, t)
+ if len(old_eps) == 0:
+ # Pseudo Improved Euler (2nd order)
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+ e_t_next = get_model_output(x_prev, t_next)
+ e_t_prime = (e_t + e_t_next) / 2
+ elif len(old_eps) == 1:
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
+ elif len(old_eps) == 2:
+            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+ elif len(old_eps) >= 3:
+            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
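+
+        # Note (added for clarity): the branches above apply the standard Adams-Bashforth
+        # multistep coefficients to the successive noise predictions, i.e.
+        #   AB2: (3*e_t - e_{t-1}) / 2
+        #   AB3: (23*e_t - 16*e_{t-1} + 5*e_{t-2}) / 12
+        #   AB4: (55*e_t - 59*e_{t-1} + 37*e_{t-2} - 9*e_{t-3}) / 24
+        # with a pseudo improved-Euler (Heun) step used only to bootstrap the first iteration.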
+
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+ return x_prev, pred_x0, e_t
diff --git a/StableSR/ldm/models/respace.py b/StableSR/ldm/models/respace.py
new file mode 100644
index 0000000000000000000000000000000000000000..077653b08ff9af56955914af0478f110b238848d
--- /dev/null
+++ b/StableSR/ldm/models/respace.py
@@ -0,0 +1,116 @@
+import numpy as np
+import torch as th
+
+# from .gaussian_diffusion import GaussianDiffusion
+
+
+def space_timesteps(num_timesteps, section_counts):
+ """
+ Create a list of timesteps to use from an original diffusion process,
+ given the number of timesteps we want to take from equally-sized portions
+ of the original process.
+
+    For example, if there are 300 timesteps and the section counts are [10, 15, 20],
+    then the first 100 timesteps are strided down to 10 timesteps, the second 100
+    are strided down to 15 timesteps, and the final 100 are strided down to 20.
+
+ If the stride is a string starting with "ddim", then the fixed striding
+ from the DDIM paper is used, and only one section is allowed.
+
+ :param num_timesteps: the number of diffusion steps in the original
+ process to divide up.
+ :param section_counts: either a list of numbers, or a string containing
+ comma-separated numbers, indicating the step count
+ per section. As a special case, use "ddimN" where N
+ is a number of steps to use the striding from the
+ DDIM paper.
+ :return: a set of diffusion steps from the original process to use.
+ """
+ if isinstance(section_counts, str):
+ if section_counts.startswith("ddim"):
+ desired_count = int(section_counts[len("ddim"):])
+ for i in range(1, num_timesteps):
+ if len(range(0, num_timesteps, i)) == desired_count:
+ return set(range(0, num_timesteps, i))
+ raise ValueError(
+                f"cannot create exactly {desired_count} steps with an integer stride"
+ )
+ section_counts = [int(x) for x in section_counts.split(",")] #[250,]
+ size_per = num_timesteps // len(section_counts)
+ extra = num_timesteps % len(section_counts)
+ start_idx = 0
+ all_steps = []
+ for i, section_count in enumerate(section_counts):
+ size = size_per + (1 if i < extra else 0)
+ if size < section_count:
+ raise ValueError(
+ f"cannot divide section of {size} steps into {section_count}"
+ )
+ if section_count <= 1:
+ frac_stride = 1
+ else:
+ frac_stride = (size - 1) / (section_count - 1)
+ cur_idx = 0.0
+ taken_steps = []
+ for _ in range(section_count):
+ taken_steps.append(start_idx + round(cur_idx))
+ cur_idx += frac_stride
+ all_steps += taken_steps
+ start_idx += size
+ return set(all_steps)
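+
+# Small worked example (added for illustration; not used anywhere in this file):
+#   space_timesteps(1000, "ddim50")    -> the 50 evenly strided steps {0, 20, 40, ..., 980}
+#   space_timesteps(1000, [250])       -> 250 steps spread uniformly over the 1000-step process
+#   space_timesteps(300, [10, 15, 20]) -> 10 + 15 + 20 steps drawn from three 100-step sections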
+
+# class SpacedDiffusion(GaussianDiffusion):
+# """
+# A diffusion process which can skip steps in a base diffusion process.
+#
+# :param use_timesteps: a collection (sequence or set) of timesteps from the
+# original diffusion process to retain.
+# :param kwargs: the kwargs to create the base diffusion process.
+# """
+#
+# def __init__(self, use_timesteps, **kwargs):
+# self.use_timesteps = set(use_timesteps)
+# self.timestep_map = []
+# self.original_num_steps = len(kwargs["betas"])
+#
+# base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
+# last_alpha_cumprod = 1.0
+# new_betas = []
+# for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+# if i in self.use_timesteps:
+# new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+# last_alpha_cumprod = alpha_cumprod
+# self.timestep_map.append(i)
+# kwargs["betas"] = np.array(new_betas)
+# super().__init__(**kwargs)
+#
+# def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs
+# return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
+#
+# def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs
+# return super().training_losses(self._wrap_model(model), *args, **kwargs)
+#
+# def _wrap_model(self, model):
+# if isinstance(model, _WrappedModel):
+# return model
+# return _WrappedModel(
+# model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
+# )
+#
+# def _scale_timesteps(self, t):
+# # Scaling is done by the wrapped model.
+# return t
+
+class _WrappedModel:
+ def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
+ self.model = model
+ self.timestep_map = timestep_map
+ self.rescale_timesteps = rescale_timesteps
+ self.original_num_steps = original_num_steps
+
+ def __call__(self, x, ts, **kwargs):
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
+ new_ts = map_tensor[ts]
+ if self.rescale_timesteps:
+ new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
+ return self.model(x, new_ts, **kwargs)
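+
+# Illustrative note (added): with timestep_map = [0, 250, 500, 750] and original_num_steps = 1000,
+# a call with ts = tensor([1, 3]) evaluates the wrapped model at the original timesteps
+# tensor([250, 750]), optionally rescaled to the 0-1000 float range when rescale_timesteps is set.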
diff --git a/StableSR/ldm/modules/attention.py b/StableSR/ldm/modules/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..89b11a9ec385e28dc2161c02faa642950df0cfac
--- /dev/null
+++ b/StableSR/ldm/modules/attention.py
@@ -0,0 +1,412 @@
+from inspect import isfunction
+from typing import Any, Optional
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+
+from ldm.modules.diffusionmodules.util import checkpoint
+
+try:
+ import xformers
+ import xformers.ops
+ XFORMERS_IS_AVAILBLE = True
+except:
+ XFORMERS_IS_AVAILBLE = False
+
+# CrossAttn precision handling
+import os
+_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
+
+def exists(val):
+ return val is not None
+
+
+def uniq(arr):
+    return {el: True for el in arr}.keys()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(
+ nn.Linear(dim, inner_dim),
+ nn.GELU()
+ ) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(
+ project_in,
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
+ k = k.softmax(dim=-1)
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
+ return self.to_out(out)
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = rearrange(q, 'b c h w -> b (h w) c')
+ k = rearrange(k, 'b c h w -> b c (h w)')
+ w_ = torch.einsum('bij,bjk->bik', q, k)
+
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, 'b c h w -> b c (h w)')
+ w_ = rearrange(w_, 'b i j -> b j i')
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+class CrossAttention(nn.Module):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ if exists(mask):
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', attn, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out)
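+
+    # Shape sketch (added comment; sizes are illustrative):
+    #   x: (b, n, query_dim), context: (b, m, context_dim); context defaults to x (self-attention).
+    #   q/k/v are reshaped to ((b*heads), n_or_m, dim_head), sim is ((b*heads), n, m),
+    #   and the attended output is projected back to (b, n, query_dim) by to_out.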
+
+class MemoryEfficientCrossAttention(nn.Module):
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+ super().__init__()
+ print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
+ f"{heads} heads.")
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.heads = heads
+ self.dim_head = dim_head
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+ self.attention_op: Optional[Any] = None
+
+ def forward(self, x, context=None, mask=None):
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ b, _, _ = q.shape
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
+ .contiguous(),
+ (q, k, v),
+ )
+
+ # actually compute the attention, what we cannot get enough of
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+ if exists(mask):
+ raise NotImplementedError
+ out = (
+ out.unsqueeze(0)
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
+ )
+ return self.to_out(out)
+
+class BasicTransformerBlock(nn.Module):
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=False):
+ super().__init__()
+ self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None):
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+
+ def _forward(self, x, context=None):
+ x = self.attn1(self.norm1(x)) + x
+ x = self.attn2(self.norm2(x), context=context) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+class BasicTransformerBlockV2(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention, # vanilla attention
+ "softmax-xformers": MemoryEfficientCrossAttention
+ }
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
+ disable_self_attn=False):
+ super().__init__()
+ attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
+ assert attn_mode in self.ATTENTION_MODES
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+ self.disable_self_attn = disable_self_attn
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None):
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+
+ def _forward(self, x, context=None):
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
+ x = self.attn2(self.norm2(x), context=context) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data.
+ First, project the input (aka embedding)
+ and reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape to image
+ """
+ def __init__(self, in_channels, n_heads, d_head,
+ depth=1, dropout=0., context_dim=None):
+ super().__init__()
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+
+ self.proj_in = nn.Conv2d(in_channels,
+ inner_dim,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ self.transformer_blocks = nn.ModuleList(
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
+ for d in range(depth)]
+ )
+
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0))
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ x = self.proj_in(x)
+ x = rearrange(x, 'b c h w -> b (h w) c')
+ for block in self.transformer_blocks:
+ x = block(x, context=context)
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
+ x = self.proj_out(x)
+ return x + x_in
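+
+    # Minimal usage sketch (added; the numbers are illustrative, not from any config in this repo):
+    #   st  = SpatialTransformer(in_channels=320, n_heads=8, d_head=40, context_dim=768)
+    #   x   = torch.randn(2, 320, 32, 32)   # UNet feature map
+    #   ctx = torch.randn(2, 77, 768)       # e.g. text-encoder states
+    #   y   = st(x, context=ctx)            # same shape as x, residual connection included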
+
+class SpatialTransformerV2(nn.Module):
+ """
+ Transformer block for image-like data.
+ First, project the input (aka embedding)
+ and reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape to image
+ NEW: use_linear for more efficiency instead of the 1x1 convs
+ """
+ def __init__(self, in_channels, n_heads, d_head,
+ depth=1, dropout=0., context_dim=None,
+ disable_self_attn=False, use_linear=False,
+ use_checkpoint=False):
+ super().__init__()
+ if exists(context_dim) and not isinstance(context_dim, list):
+ context_dim = [context_dim]
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+ if not use_linear:
+ self.proj_in = nn.Conv2d(in_channels,
+ inner_dim,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ else:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [BasicTransformerBlockV2(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
+ for d in range(depth)]
+ )
+ if not use_linear:
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0))
+ else:
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+ self.use_linear = use_linear
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ if not isinstance(context, list):
+ context = [context]
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ if not self.use_linear:
+ x = self.proj_in(x)
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+ if self.use_linear:
+ x = self.proj_in(x)
+ for i, block in enumerate(self.transformer_blocks):
+ x = block(x, context=context[i])
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+ if not self.use_linear:
+ x = self.proj_out(x)
+ return x + x_in
diff --git a/StableSR/ldm/modules/diffusionmodules/__init__.py b/StableSR/ldm/modules/diffusionmodules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/StableSR/ldm/modules/diffusionmodules/model.py b/StableSR/ldm/modules/diffusionmodules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ae94c06bfb48f1cc189de8fcf1050d69c8993c3
--- /dev/null
+++ b/StableSR/ldm/modules/diffusionmodules/model.py
@@ -0,0 +1,1103 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+from typing import Any, Optional
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+
+from ldm.util import instantiate_from_config
+from ldm.modules.attention import LinearAttention
+
+from basicsr.archs.arch_util import default_init_weights, make_layer, pixel_unshuffle
+from basicsr.archs.rrdbnet_arch import RRDB
+
+try:
+ import xformers
+ import xformers.ops
+ XFORMERS_IS_AVAILBLE = True
+except:
+ XFORMERS_IS_AVAILBLE = False
+
+def calc_mean_std(feat, eps=1e-5):
+ """Calculate mean and std for adaptive_instance_normalization.
+ Args:
+ feat (Tensor): 4D tensor.
+ eps (float): A small value added to the variance to avoid
+ divide-by-zero. Default: 1e-5.
+ """
+ size = feat.size()
+ assert len(size) == 4, 'The input feature should be 4D tensor.'
+ b, c = size[:2]
+ feat_var = feat.view(b, c, -1).var(dim=2) + eps
+ feat_std = feat_var.sqrt().view(b, c, 1, 1)
+ feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
+ return feat_mean, feat_std
+
+def adaptive_instance_normalization(content_feat, style_feat):
+    """Adaptive instance normalization.
+    Adjust the reference features to have color and illumination statistics similar
+    to those of the degraded features.
+    Args:
+        content_feat (Tensor): The reference features.
+        style_feat (Tensor): The degraded features.
+    """
+ size = content_feat.size()
+ style_mean, style_std = calc_mean_std(style_feat)
+ content_mean, content_std = calc_mean_std(content_feat)
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
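+
+# For reference (added note): this is the usual AdaIN transform
+#   AdaIN(c, s) = sigma(s) * (c - mu(c)) / sigma(c) + mu(s)
+# applied per sample and per channel, with mu/sigma computed by calc_mean_std above.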
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+    Build sinusoidal timestep embeddings.
+    This matches the implementation in Denoising Diffusion Probabilistic Models
+    (originally from Fairseq / tensor2tensor), but differs slightly from the
+    description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
+ return emb
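+
+# Example (added; shapes only): for timesteps = tensor([0, 1, 2]) and embedding_dim = 128,
+# this returns a (3, 128) tensor whose first 64 columns are sin(t / 10000^(k/63)), k = 0..63,
+# and whose last 64 columns are the matching cosines -- the standard sinusoidal timestep encoding.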
+
+
+def nonlinearity(x):
+ # swish
+ return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0,1,0,1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+ dropout, temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels,
+ out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x+h
+
+
+class LinAttnBlock(LinearAttention):
+ """to match AttnBlock usage"""
+ def __init__(self, in_channels):
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = q.reshape(b,c,h*w)
+ q = q.permute(0,2,1) # b,hw,c
+ k = k.reshape(b,c,h*w) # b,c,hw
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b,c,h*w)
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b,c,h,w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+class MemoryEfficientAttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.attention_op: Optional[Any] = None
+
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q, k, v = map(
+ lambda t:t.reshape(b, t.shape[1], t.shape[2]*t.shape[3], 1)
+ .squeeze(3)
+ .permute(0,2,1)
+ .contiguous(),
+ (q, k, v),
+ )
+
+ # actually compute the attention, what we cannot get enough of
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, scale=(int(c)**(-0.5)), op=self.attention_op)
+
+ h_ = (
+ out.permute(0,2,1)
+ .unsqueeze(3)
+ .reshape(b, c, h, w)
+ )
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+def make_attn(in_channels, attn_type="vanilla"):
+ assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ if attn_type == "vanilla":
+ if XFORMERS_IS_AVAILBLE:
+ return MemoryEfficientAttnBlock(in_channels)
+ else:
+ return AttnBlock(in_channels)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ return LinAttnBlock(in_channels)
+
+
+class Model(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = self.ch*4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch,
+ self.temb_ch),
+ torch.nn.Linear(self.temb_ch,
+ self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ skip_in = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch*in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x, t=None, context=None):
+ #assert x.shape[2] == x.shape[3] == self.resolution
+ if context is not None:
+ # assume aligned context, cat along channel axis
+ x = torch.cat((x, context), dim=1)
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
+ **ignore_kwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ 2*z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x, return_fea=False):
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ fea_list = []
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if return_fea:
+ if i_level==1 or i_level==2:
+ fea_list.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+
+ if return_fea:
+ return h, fea_list
+
+ return h
+
+class Decoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+ attn_type="vanilla", **ignorekwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,)+tuple(ch_mult)
+ block_in = ch*ch_mult[self.num_resolutions-1]
+ curr_res = resolution // 2**(self.num_resolutions-1)
+ self.z_shape = (1,z_channels,curr_res,curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, z):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
+
+class Decoder_Mix(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+ attn_type="vanilla", num_fuse_block=2, fusion_w=1.0, **ignorekwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+ self.fusion_w = fusion_w
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,)+tuple(ch_mult)
+ block_in = ch*ch_mult[self.num_resolutions-1]
+ curr_res = resolution // 2**(self.num_resolutions-1)
+ self.z_shape = (1,z_channels,curr_res,curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+
+ if i_level != self.num_resolutions-1:
+ if i_level != 0:
+ fuse_layer = Fuse_sft_block_RRDB(in_ch=block_out, out_ch=block_out, num_block=num_fuse_block)
+ setattr(self, 'fusion_layer_{}'.format(i_level), fuse_layer)
+
+ for i_block in range(self.num_res_blocks+1):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, z, enc_fea):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+
+ if i_level != self.num_resolutions-1 and i_level != 0:
+ cur_fuse_layer = getattr(self, 'fusion_layer_{}'.format(i_level))
+ h = cur_fuse_layer(enc_fea[i_level-1], h, self.fusion_w)
+
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
+
+class ResBlock(nn.Module):
+ def __init__(self, in_channels, out_channels=None):
+ super(ResBlock, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.norm1 = Normalize(in_channels)
+        # use self.out_channels so that the default out_channels=None case does not crash
+        self.conv1 = nn.Conv2d(in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
+        self.norm2 = Normalize(self.out_channels)
+        self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
+        if self.in_channels != self.out_channels:
+            self.conv_out = nn.Conv2d(in_channels, self.out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x_in):
+ x = x_in
+ x = self.norm1(x)
+ x = nonlinearity(x)
+ x = self.conv1(x)
+ x = self.norm2(x)
+ x = nonlinearity(x)
+ x = self.conv2(x)
+ if self.in_channels != self.out_channels:
+ x_in = self.conv_out(x_in)
+
+ return x + x_in
+
+class Fuse_sft_block_RRDB(nn.Module):
+ def __init__(self, in_ch, out_ch, num_block=1, num_grow_ch=32):
+ super().__init__()
+ self.encode_enc_1 = ResBlock(2*in_ch, in_ch)
+ self.encode_enc_2 = make_layer(RRDB, num_block, num_feat=in_ch, num_grow_ch=num_grow_ch)
+ self.encode_enc_3 = ResBlock(in_ch, out_ch)
+
+ def forward(self, enc_feat, dec_feat, w=1):
+ enc_feat = self.encode_enc_1(torch.cat([enc_feat, dec_feat], dim=1))
+ enc_feat = self.encode_enc_2(enc_feat)
+ enc_feat = self.encode_enc_3(enc_feat)
+ residual = w * enc_feat
+ out = dec_feat + residual
+ return out
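+
+    # Added note: `w` scales how strongly the fused encoder features are injected back into the
+    # decoder features (w = 0 disables the fusion, w = 1 adds the full learned residual); it is
+    # driven by the `fusion_w` argument of Decoder_Mix above.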
+
+class SimpleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ super().__init__()
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
+ ResnetBlock(in_channels=in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=2 * in_channels,
+ out_channels=4 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=4 * in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ nn.Conv2d(2*in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True)])
+ # end
+ self.norm_out = Normalize(in_channels)
+ self.conv_out = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ for i, layer in enumerate(self.model):
+ if i in [1,2,3]:
+ x = layer(x, None)
+ else:
+ x = layer(x)
+
+ h = self.norm_out(x)
+ h = nonlinearity(h)
+ x = self.conv_out(h)
+ return x
+
+
+class UpsampleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+ ch_mult=(2,2), dropout=0.0):
+ super().__init__()
+ # upsampling
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ block_in = in_channels
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.res_blocks = nn.ModuleList()
+ self.upsample_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ res_block = []
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ res_block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ self.res_blocks.append(nn.ModuleList(res_block))
+ if i_level != self.num_resolutions - 1:
+ self.upsample_blocks.append(Upsample(block_in, True))
+ curr_res = curr_res * 2
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # upsampling
+ h = x
+ for k, i_level in enumerate(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.res_blocks[i_level][i_block](h, None)
+ if i_level != self.num_resolutions - 1:
+ h = self.upsample_blocks[k](h)
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class LatentRescaler(nn.Module):
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
+ super().__init__()
+ # residual block, interpolate, residual block
+ self.factor = factor
+ self.conv_in = nn.Conv2d(in_channels,
+ mid_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+ if XFORMERS_IS_AVAILBLE:
+ self.attn = MemoryEfficientAttnBlock(mid_channels)
+ else:
+ self.attn = AttnBlock(mid_channels)
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+
+ self.conv_out = nn.Conv2d(mid_channels,
+ out_channels,
+ kernel_size=1,
+ )
+
+ def forward(self, x):
+ x = self.conv_in(x)
+ for block in self.res_block1:
+ x = block(x, None)
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
+ x = self.attn(x)
+ for block in self.res_block2:
+ x = block(x, None)
+ x = self.conv_out(x)
+ return x
+
+
+class MergedRescaleEncoder(nn.Module):
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ intermediate_chn = ch * ch_mult[-1]
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
+ out_ch=None)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.encoder(x)
+ x = self.rescaler(x)
+ return x
+
+
+class MergedRescaleDecoder(nn.Module):
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ tmp_chn = z_channels*ch_mult[-1]
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
+ out_channels=tmp_chn, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Upsampler(nn.Module):
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+ super().__init__()
+ assert out_size >= in_size
+ num_blocks = int(np.log2(out_size//in_size))+1
+ factor_up = 1. + (out_size % in_size)
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
+ out_channels=in_channels)
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
+ attn_resolutions=[], in_channels=None, ch=in_channels,
+ ch_mult=[ch_mult for _ in range(num_blocks)])
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
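+
+# Worked example (illustrative): for in_size=64 and out_size=256,
+# num_blocks = int(np.log2(256 // 64)) + 1 = 3 decoder levels and
+# factor_up = 1. + (256 % 64) = 1.0, so the LatentRescaler keeps the spatial
+# size unchanged and the Decoder alone performs the 4x upsampling.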
+
+
+class Resize(nn.Module):
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+ super().__init__()
+ self.with_conv = learned
+ self.mode = mode
+ if self.with_conv:
+ print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
+ raise NotImplementedError()
+ assert in_channels is not None
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1)
+
+ def forward(self, x, scale_factor=1.0):
+ if scale_factor==1.0:
+ return x
+ else:
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
+ return x
+
+class FirstStagePostProcessor(nn.Module):
+
+ def __init__(self, ch_mult:list, in_channels,
+ pretrained_model:nn.Module=None,
+ reshape=False,
+ n_channels=None,
+ dropout=0.,
+ pretrained_config=None):
+ super().__init__()
+ if pretrained_config is None:
+ assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
+ self.pretrained_model = pretrained_model
+ else:
+ assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
+ self.instantiate_pretrained(pretrained_config)
+
+ self.do_reshape = reshape
+
+ if n_channels is None:
+ n_channels = self.pretrained_model.encoder.ch
+
+ self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
+ self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
+ stride=1,padding=1)
+
+ blocks = []
+ downs = []
+ ch_in = n_channels
+ for m in ch_mult:
+ blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
+ ch_in = m * n_channels
+ downs.append(Downsample(ch_in, with_conv=False))
+
+ self.model = nn.ModuleList(blocks)
+ self.downsampler = nn.ModuleList(downs)
+
+
+ def instantiate_pretrained(self, config):
+ model = instantiate_from_config(config)
+ self.pretrained_model = model.eval()
+ # self.pretrained_model.train = False
+ for param in self.pretrained_model.parameters():
+ param.requires_grad = False
+
+
+ @torch.no_grad()
+ def encode_with_pretrained(self,x):
+ c = self.pretrained_model.encode(x)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ return c
+
+ def forward(self,x):
+ z_fs = self.encode_with_pretrained(x)
+ z = self.proj_norm(z_fs)
+ z = self.proj(z)
+ z = nonlinearity(z)
+
+ for submodel, downmodel in zip(self.model,self.downsampler):
+ z = submodel(z,temb=None)
+ z = downmodel(z)
+
+ if self.do_reshape:
+ z = rearrange(z,'b c h w -> b (h w) c')
+ return z
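+
+# Usage sketch (illustrative; n_channels and ch_mult values are assumptions):
+# with n_channels=64 and ch_mult=[1, 2], the encoded latent is projected to
+# 64 channels, then passed through two ResnetBlock + Downsample stages,
+# giving 128 channels at 1/4 of the latent resolution (flattened to
+# 'b (h w) c' when reshape=True).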
diff --git a/StableSR/ldm/modules/diffusionmodules/openaimodel.py b/StableSR/ldm/modules/diffusionmodules/openaimodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..6aa3f5b26db1117564de1f41e16353b1858d732b
--- /dev/null
+++ b/StableSR/ldm/modules/diffusionmodules/openaimodel.py
@@ -0,0 +1,1541 @@
+from abc import abstractmethod
+from functools import partial
+import math
+from typing import Any, Iterable, Optional
+import torch
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+try:
+ import xformers
+ import xformers.ops
+ XFORMERS_IS_AVAILBLE = True
+except Exception:
+ XFORMERS_IS_AVAILBLE = False
+
+from ldm.modules.diffusionmodules.util import (
+ checkpoint,
+ conv_nd,
+ linear,
+ avg_pool_nd,
+ zero_module,
+ normalization,
+ timestep_embedding,
+)
+from ldm.modules.attention import SpatialTransformer, SpatialTransformerV2
+from ldm.modules.spade import SPADE
+
+from basicsr.archs.stylegan2_arch import ConvLayer, EqualConv2d
+# dummy replace
+def convert_module_to_f16(x):
+ pass
+
+def convert_module_to_f32(x):
+ pass
+
+def exists(val):
+ return val is not None
+
+def cal_fea_cossim(fea_1, fea_2, save_dir=None):
+ cossim_fuc = nn.CosineSimilarity(dim=-1, eps=1e-6)
+ # fall back to the default debug files when no path is given; otherwise log to save_dir
+ if save_dir is None:
+ save_dir_1 = './cos_sim64_1_not.txt'
+ save_dir_2 = './cos_sim64_2_not.txt'
+ else:
+ save_dir_1 = save_dir_2 = save_dir
+ b, c, h, w = fea_1.size()
+ fea_1 = fea_1.reshape(b, c, h*w)
+ fea_2 = fea_2.reshape(b, c, h*w)
+ cos_sim = cossim_fuc(fea_1, fea_2)
+ cos_sim = cos_sim.data.cpu().numpy()
+ with open(save_dir_1, "a") as my_file:
+ my_file.write(str(np.mean(cos_sim[0])) + "\n")
+ # with open(save_dir_2, "a") as my_file:
+ # my_file.write(str(np.mean(cos_sim[1])) + "\n")
+
+## go
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ b, c, *_spatial = x.shape
+ x = x.reshape(b, c, -1) # NC(HW)
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+class TimestepBlockDual(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings and a structural condition as extra arguments.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb, cond):
+ """
+ Apply the module to `x` given `emb` timestep embeddings and the structural condition `cond`.
+ """
+
+class TimestepBlock3cond(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings plus structural and segmentation conditions as extra arguments.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb, s_cond, seg_cond):
+ """
+ Apply the module to `x` given `emb` timestep embeddings and the conditions `s_cond` and `seg_cond`.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb, context=None, struct_cond=None, seg_cond=None):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, SpatialTransformer) or isinstance(layer, SpatialTransformerV2):
+ assert context is not None
+ x = layer(x, context)
+ elif isinstance(layer, TimestepBlockDual):
+ assert struct_cond is not None
+ x = layer(x, emb, struct_cond)
+ elif isinstance(layer, TimestepBlock3cond):
+ assert seg_cond is not None
+ x = layer(x, emb, struct_cond, seg_cond)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+class TransposedUpsample(nn.Module):
+ 'Learned 2x upsampling without padding'
+ def __init__(self, channels, out_channels=None, ks=5):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
+
+ def forward(self,x):
+ return self.up(x)
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+
+ if self.out_channels % 32 == 0:
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+ else:
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels, self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ return checkpoint(
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
+ )
+
+
+ def _forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+class ResBlockDual(TimestepBlockDual):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ semb_channels,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+
+ # Here we use the building block from SPADE rather than SFT; this should have no significant influence on performance.
+ self.spade = SPADE(self.out_channels, semb_channels)
+
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb, s_cond):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ return checkpoint(
+ self._forward, (x, emb, s_cond), self.parameters(), self.use_checkpoint
+ )
+
+
+ def _forward(self, x, emb, s_cond):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ h = self.spade(h, s_cond)
+ return self.skip_connection(x) + h
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x):
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+ #return pt_checkpoint(self._forward, x) # pytorch
+
+ def _forward(self, x):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
+ model.total_ops += th.DoubleTensor([matmul_ops])
+
+class QKVAttentionLegacy(nn.Module):
+ """
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+ self.attention_op: Optional[Any] = None
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ if XFORMERS_IS_AVAILBLE:
+ q, k, v = map(
+ lambda t:t.permute(0,2,1)
+ .contiguous(),
+ (q, k, v),
+ )
+ # actually compute the attention, what we cannot get enough of
+ a = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+ a = (
+ a.permute(0,2,1)
+ .reshape(bs, -1, length)
+ )
+ else:
+ weight = th.einsum(
+ "bct,bcs->bts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ a = a.reshape(bs, -1, length)
+ return a
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
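+
+# Shape sketch (illustrative): with n_heads = H, QKVAttentionLegacy expects
+# qkv of shape [N, H * 3 * C, T] (T = number of spatial positions), splits it
+# per head into q, k, v of shape [N * H, C, T], and returns an output of
+# shape [N, H * C, T].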
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+ self.attention_op: Optional[Any] = None
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ if XFORMERS_IS_AVAILBLE:
+ q, k, v = map(
+ lambda t:t.permute(0,2,1)
+ .contiguous(),
+ (q, k, v),
+ )
+ # actually compute the attention, what we cannot get enough of
+ a = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+ a = (
+ a.permute(0,2,1)
+ .reshape(bs, -1, length)
+ )
+ else:
+ weight = th.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+ a = a.reshape(bs, -1, length)
+ return a
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_head_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ ):
+ super().__init__()
+ if use_spatial_transformer:
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+ if context_dim is not None:
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+ from omegaconf.listconfig import ListConfig
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(num_res_blocks + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+ )
+ )
+ if level and i == num_res_blocks:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+ self.output_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+ self.output_blocks.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape == (x.shape[0],)
+ emb = emb + self.label_emb(y)
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ return self.out(h)
+
+class UNetModelDualcondV2(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_head_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ disable_self_attentions=None,
+ num_attention_blocks=None,
+ disable_middle_self_attn=False,
+ use_linear_in_transformer=False,
+ semb_channels=None
+ ):
+ super().__init__()
+ if use_spatial_transformer:
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+ if context_dim is not None:
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+ from omegaconf.listconfig import ListConfig
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult")
+ self.num_res_blocks = num_res_blocks
+ if disable_self_attentions is not None:
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+ assert len(disable_self_attentions) == len(channel_mult)
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+ f"attention will still not be set.")
+
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ if isinstance(self.num_classes, int):
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+ elif self.num_classes == "continuous":
+ print("setting up linear c_adm embedding layer")
+ self.label_emb = nn.Linear(1, time_embed_dim)
+ else:
+ raise ValueError()
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ ResBlockDual(
+ ch,
+ time_embed_dim,
+ dropout,
+ semb_channels=semb_channels,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformerV2(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlockDual(
+ ch,
+ time_embed_dim,
+ dropout,
+ semb_channels=semb_channels,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ ResBlockDual(
+ ch,
+ time_embed_dim,
+ dropout,
+ semb_channels=semb_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformerV2( # always uses a self-attn
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ ),
+ ResBlockDual(
+ ch,
+ time_embed_dim,
+ dropout,
+ semb_channels=semb_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(self.num_res_blocks[level] + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlockDual(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ semb_channels=semb_channels,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformerV2(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ )
+ )
+ if level and i == self.num_res_blocks[level]:
+ out_ch = ch
+ layers.append(
+ ResBlockDual(
+ ch,
+ time_embed_dim,
+ dropout,
+ semb_channels=semb_channels,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+ self.output_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+ self.output_blocks.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps=None, context=None, struct_cond=None, y=None,**kwargs):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape == (x.shape[0],)
+ emb = emb + self.label_emb(y)
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb, context, struct_cond)
+ hs.append(h)
+ h = self.middle_block(h, emb, context, struct_cond)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context, struct_cond)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ return self.out(h)
+
+class EncoderUNetModelWT(nn.Module):
+ """
+ The half UNet model with attention and timestep embedding.
+ For usage, see UNet.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ *args,
+ **kwargs
+ ):
+ super().__init__()
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = []
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ input_block_chans.append(ch)
+ self._feature_size += ch
+ self.input_block_chans = input_block_chans
+
+ self.fea_tran = nn.ModuleList([])
+
+ for i in range(len(input_block_chans)):
+ self.fea_tran.append(
+ ResBlock(
+ input_block_chans[i],
+ time_embed_dim,
+ dropout,
+ out_channels=out_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ )
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :return: an [N x K] Tensor of outputs.
+ """
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+ result_list = []
+ results = {}
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ last_h = h
+ h = module(h, emb)
+ if h.size(-1) != last_h.size(-1):
+ result_list.append(last_h)
+ h = self.middle_block(h, emb)
+ result_list.append(h)
+
+ assert len(result_list) == len(self.fea_tran)
+
+ for i in range(len(result_list)):
+ results[str(result_list[i].size(-1))] = self.fea_tran[i](result_list[i], emb)
+
+ return results
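+
+# Note (illustrative): the returned dict maps each scale's spatial size,
+# stringified (e.g. '64', '32', ...), to the corresponding feature map after
+# fea_tran. These multi-scale features serve as the structural condition
+# (struct_cond) consumed by the SPADE-conditioned ResBlockDual layers in
+# UNetModelDualcondV2.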
diff --git a/StableSR/ldm/modules/diffusionmodules/util.py b/StableSR/ldm/modules/diffusionmodules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..e77a8150d81f67ee42885098bf5d9a52a2681669
--- /dev/null
+++ b/StableSR/ldm/modules/diffusionmodules/util.py
@@ -0,0 +1,267 @@
+# adapted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+
+import os
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import repeat
+
+from ldm.util import instantiate_from_config
+
+
+def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ if schedule == "linear":
+ betas = (
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+ )
+
+ elif schedule == "cosine":
+ timesteps = (
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+ )
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
+ alphas = torch.cos(alphas).pow(2)
+ alphas = alphas / alphas[0]
+ betas = 1 - alphas[1:] / alphas[:-1]
+ betas = torch.clamp(betas, min=0, max=0.999)
+
+ elif schedule == "sqrt_linear":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+ elif schedule == "sqrt":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+ else:
+ raise ValueError(f"schedule '{schedule}' unknown.")
+ return betas.numpy()
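+
+# Worked example (illustrative): with the default "linear" schedule and
+# n_timestep=1000, the square roots of beta are evenly spaced, so
+# betas[0] = linear_start = 1e-4, betas[-1] = linear_end = 2e-2, and the
+# intermediate values grow quadratically between the two endpoints.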
+
+
+def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
+ if ddim_discr_method == 'uniform':
+ c = num_ddpm_timesteps // num_ddim_timesteps
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+ elif ddim_discr_method == 'quad':
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
+ else:
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
+
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
+ # the reference implementation adds one here to get the final alpha values right
+ # (the ones from first scale to data during sampling); this variant keeps the raw timesteps
+ steps_out = ddim_timesteps
+ if verbose:
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
+ return steps_out
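+
+# Worked example (illustrative): make_ddim_timesteps('uniform', 50, 1000)
+# gives c = 1000 // 50 = 20 and the 50 timesteps [0, 20, 40, ..., 980].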
+
+
+def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
+ # select alphas for computing the variance schedule
+ alphas = alphacums[ddim_timesteps]
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
+
+ # according to the formula provided in https://arxiv.org/abs/2010.02502
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
+ if verbose:
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
+ print(f'For the chosen value of eta, which is {eta}, '
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
+ return sigmas, alphas, alphas_prev
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
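+
+# Usage sketch (illustrative): the cosine schedule of Nichol & Dhariwal can be
+# produced with
+#   betas = betas_for_alpha_bar(
+#       1000, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2)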
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
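+
+# Shape sketch (illustrative): if `a` holds per-timestep coefficients of shape
+# [T], `t` is a batch of timestep indices of shape [B], and
+# x_shape = (B, C, H, W), the result has shape [B, 1, 1, 1] and broadcasts
+# against the image batch.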
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
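+
+# Usage sketch (mirroring the ResBlock pattern in openaimodel.py): wrap a
+# module's private forward so its activations are recomputed during backward
+# when checkpointing is enabled:
+#   return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint)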
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ else:
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
+ return embedding
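+
+# Example (illustrative): timestep_embedding(torch.tensor([0, 10, 999]), 320)
+# returns a [3, 320] tensor whose first half along dim=-1 holds cosines and
+# whose second half holds sines.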
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels, norm_channel=32):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(norm_channel, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class HybridConditioner(nn.Module):
+
+ def __init__(self, c_concat_config, c_crossattn_config):
+ super().__init__()
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+ def forward(self, c_concat, c_crossattn):
+ c_concat = self.concat_conditioner(c_concat)
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
+
+
+def noise_like(shape, device, repeat=False):
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+ noise = lambda: torch.randn(shape, device=device)
+ return repeat_noise() if repeat else noise()
diff --git a/StableSR/ldm/modules/distributions/__init__.py b/StableSR/ldm/modules/distributions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/StableSR/ldm/modules/distributions/distributions.py b/StableSR/ldm/modules/distributions/distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9
--- /dev/null
+++ b/StableSR/ldm/modules/distributions/distributions.py
@@ -0,0 +1,92 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+ def sample(self):
+ raise NotImplementedError()
+
+ def mode(self):
+ raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+ def __init__(self, value):
+ self.value = value
+
+ def sample(self):
+ return self.value
+
+ def mode(self):
+ return self.value
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
+ dim=[1, 2, 3])
+
+ def nll(self, sample, dims=[1,2,3]):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims)
+
+ def mode(self):
+ return self.mean
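+
+# Usage sketch (assumption): an autoencoder typically predicts 2*C channels
+# (mean and log-variance) and wraps them as
+#   posterior = DiagonalGaussianDistribution(moments)  # moments: [B, 2C, H, W]
+#   z = posterior.sample()
+#   kl_loss = posterior.kl()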
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, torch.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for torch.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + torch.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+ )
diff --git a/StableSR/ldm/modules/ema.py b/StableSR/ldm/modules/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..450cc844c0ce0353fb7cee371440cb901864d1a5
--- /dev/null
+++ b/StableSR/ldm/modules/ema.py
@@ -0,0 +1,78 @@
+import torch
+from torch import nn
+
+
+class LitEma(nn.Module):
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
+ super().__init__()
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError('Decay must be between 0 and 1')
+
+ self.m_name2s_name = {}
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
+ self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
+ else torch.tensor(-1,dtype=torch.int))
+
+ for name, p in model.named_parameters():
+ if p.requires_grad:
+ #remove as '.'-character is not allowed in buffers
+ s_name = name.replace('.','')
+ self.m_name2s_name.update({name:s_name})
+ self.register_buffer(s_name,p.clone().detach().data)
+
+ self.collected_params = []
+
+ def forward(self,model):
+ decay = self.decay
+
+ if self.num_updates >= 0:
+ self.num_updates += 1
+ decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
+
+ one_minus_decay = 1.0 - decay
+
+ with torch.no_grad():
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+
+ for key in m_param:
+ if m_param[key].requires_grad:
+ sname = self.m_name2s_name[key]
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
+ else:
+ pass
+ # assert not key in self.m_name2s_name
+
+ def copy_to(self, model):
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+ for key in m_param:
+ if m_param[key].requires_grad:
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+ else:
+ pass
+ # assert not key in self.m_name2s_name
+
+ def store(self, parameters):
+ """
+ Save the current parameters for restoring later.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
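+
+# Usage sketch (assumption, mirroring the common training pattern):
+#   ema = LitEma(model)
+#   ema(model)  # after each optimizer step: update the shadow weights
+#   ema.store(model.parameters()); ema.copy_to(model)  # swap in EMA weights
+#   ... run validation ...
+#   ema.restore(model.parameters())  # swap the training weights back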
diff --git a/StableSR/ldm/modules/embedding_manager.py b/StableSR/ldm/modules/embedding_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c5f97bd9f151bc1c508f56bd7ccfb4509aaea82
--- /dev/null
+++ b/StableSR/ldm/modules/embedding_manager.py
@@ -0,0 +1,161 @@
+import torch
+from torch import nn
+
+from ldm.data.personalized import per_img_token_list
+from transformers import CLIPTokenizer
+from functools import partial
+
+DEFAULT_PLACEHOLDER_TOKEN = ["*"]
+
+PROGRESSIVE_SCALE = 2000
+
+def get_clip_token_for_string(tokenizer, string):
+ batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"]
+ assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string"
+
+ return tokens[0, 1]
+
+def get_bert_token_for_string(tokenizer, string):
+ token = tokenizer(string)
+ assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
+
+ token = token[0, 1]
+
+ return token
+
+def get_embedding_for_clip_token(embedder, token):
+ return embedder(token.unsqueeze(0))[0, 0]
+
+
+class EmbeddingManager(nn.Module):
+ def __init__(
+ self,
+ embedder,
+ placeholder_strings=None,
+ initializer_words=None,
+ per_image_tokens=False,
+ num_vectors_per_token=1,
+ progressive_words=False,
+ **kwargs
+ ):
+ super().__init__()
+
+ self.string_to_token_dict = {}
+
+ self.string_to_param_dict = nn.ParameterDict()
+
+ self.initial_embeddings = nn.ParameterDict() # These should not be optimized
+
+ self.progressive_words = progressive_words
+ self.progressive_counter = 0
+
+ self.max_vectors_per_token = num_vectors_per_token
+
+ if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder
+ self.is_clip = True
+ get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
+ get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.transformer.text_model.embeddings)
+ token_dim = 768
+ else: # using LDM's BERT encoder
+ self.is_clip = False
+ get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
+ get_embedding_for_tkn = embedder.transformer.token_emb
+ token_dim = 1280
+
+ if per_image_tokens:
+ placeholder_strings.extend(per_img_token_list)
+
+ for idx, placeholder_string in enumerate(placeholder_strings):
+
+ token = get_token_for_string(placeholder_string)
+
+ if initializer_words and idx < len(initializer_words):
+ init_word_token = get_token_for_string(initializer_words[idx])
+
+ with torch.no_grad():
+ init_word_embedding = get_embedding_for_tkn(init_word_token.cpu())
+
+ token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True)
+ self.initial_embeddings[placeholder_string] = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=False)
+ else:
+ token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True))
+
+ self.string_to_token_dict[placeholder_string] = token
+ self.string_to_param_dict[placeholder_string] = token_params
+
+ def forward(
+ self,
+ tokenized_text,
+ embedded_text,
+ ):
+ b, n, device = *tokenized_text.shape, tokenized_text.device
+
+ for placeholder_string, placeholder_token in self.string_to_token_dict.items():
+
+ placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device)
+
+ if self.max_vectors_per_token == 1: # If there's only one vector per token, we can do a simple replacement
+ placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device))
+ embedded_text[placeholder_idx] = placeholder_embedding
+ else: # otherwise, need to insert and keep track of changing indices
+ if self.progressive_words:
+ self.progressive_counter += 1
+ max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE
+ else:
+ max_step_tokens = self.max_vectors_per_token
+
+ num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens)
+
+ placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device))
+
+ if placeholder_rows.nelement() == 0:
+ continue
+
+ sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True)
+ sorted_rows = placeholder_rows[sort_idx]
+
+ for idx in range(len(sorted_rows)):
+ row = sorted_rows[idx]
+ col = sorted_cols[idx]
+
+ new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 1:]], axis=0)[:n]
+ new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n]
+
+ embedded_text[row] = new_embed_row
+ tokenized_text[row] = new_token_row
+
+ return embedded_text
+
+ def save(self, ckpt_path):
+ torch.save({"string_to_token": self.string_to_token_dict,
+ "string_to_param": self.string_to_param_dict}, ckpt_path)
+
+ def load(self, ckpt_path):
+ ckpt = torch.load(ckpt_path, map_location='cpu')
+
+ self.string_to_token_dict = ckpt["string_to_token"]
+ self.string_to_param_dict = ckpt["string_to_param"]
+
+ def get_embedding_norms_squared(self):
+ all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim
+ param_norm_squared = (all_params * all_params).sum(axis=-1) # num_placeholders
+
+ return param_norm_squared
+
+ def embedding_parameters(self):
+ return self.string_to_param_dict.parameters()
+
+ def embedding_to_coarse_loss(self):
+
+ loss = 0.
+ num_embeddings = len(self.initial_embeddings)
+
+ for key in self.initial_embeddings:
+ optimized = self.string_to_param_dict[key]
+ coarse = self.initial_embeddings[key].clone().to(optimized.device)
+
+ loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings
+
+ return loss
diff --git a/StableSR/ldm/modules/encoders/__init__.py b/StableSR/ldm/modules/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/StableSR/ldm/modules/encoders/modules.py b/StableSR/ldm/modules/encoders/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2ac91a1205d6746e75ba173170080f2f37ce377
--- /dev/null
+++ b/StableSR/ldm/modules/encoders/modules.py
@@ -0,0 +1,484 @@
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint # used by FrozenOpenCLIPEmbedder.text_transformer_forward
+from functools import partial
+import clip
+from einops import rearrange, repeat
+import transformers
+from transformers import CLIPTokenizer, CLIPTextModel
+import kornia
+
+from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
+from .transformer_utils import CLIPTextTransformer_M
+import open_clip
+
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+
+class ClassEmbedder(nn.Module):
+ def __init__(self, embed_dim, n_classes=1000, key='class'):
+ super().__init__()
+ self.key = key
+ self.embedding = nn.Embedding(n_classes, embed_dim)
+
+ def forward(self, batch, key=None):
+ if key is None:
+ key = self.key
+ # this is for use in crossattn
+ c = batch[key][:, None]
+ c = self.embedding(c)
+ return c
+
+
+class TransformerEmbedder(AbstractEncoder):
+ """Some transformer encoder layers"""
+ def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
+ super().__init__()
+ self.device = device
+ self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
+ attn_layers=Encoder(dim=n_embed, depth=n_layer))
+
+ def forward(self, tokens):
+ tokens = tokens.to(self.device) # meh
+ z = self.transformer(tokens, return_embeddings=True)
+ return z
+
+ def encode(self, x):
+ return self(x)
+
+
+class BERTTokenizer(AbstractEncoder):
+ """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
+ def __init__(self, device="cuda", vq_interface=True, max_length=77):
+ super().__init__()
+ from transformers import BertTokenizerFast # TODO: add to requirements
+ self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
+ self.device = device
+ self.vq_interface = vq_interface
+ self.max_length = max_length
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ return tokens
+
+ @torch.no_grad()
+ def encode(self, text):
+ tokens = self(text)
+ if not self.vq_interface:
+ return tokens
+ return None, None, [None, None, tokens]
+
+ def decode(self, text):
+ return text
+
+
+class BERTEmbedder(AbstractEncoder):
+ """Uses the BERT tokenizr model and add some transformer encoder layers"""
+ def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
+ device="cuda",use_tokenizer=True, embedding_dropout=0.0):
+ super().__init__()
+ self.use_tknz_fn = use_tokenizer
+ if self.use_tknz_fn:
+ self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
+ self.device = device
+ self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
+ attn_layers=Encoder(dim=n_embed, depth=n_layer),
+ emb_dropout=embedding_dropout)
+
+ def forward(self, text):
+ if self.use_tknz_fn:
+ tokens = self.tknz_fn(text)#.to(self.device)
+ else:
+ tokens = text
+ z = self.transformer(tokens, return_embeddings=True)
+ return z
+
+ def encode(self, text):
+ # output of length 77
+ return self(text)
+
+
+class SpatialRescaler(nn.Module):
+ def __init__(self,
+ n_stages=1,
+ method='bilinear',
+ multiplier=0.5,
+ in_channels=3,
+ out_channels=None,
+ bias=False):
+ super().__init__()
+ self.n_stages = n_stages
+ assert self.n_stages >= 0
+ assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
+ self.multiplier = multiplier
+ self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
+ self.remap_output = out_channels is not None
+ if self.remap_output:
+ print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
+ self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
+
+ def forward(self,x):
+ for stage in range(self.n_stages):
+ x = self.interpolator(x, scale_factor=self.multiplier)
+
+
+ if self.remap_output:
+ x = self.channel_mapper(x)
+ return x
+
+ def encode(self, x):
+ return self(x)
+
+class FrozenOpenCLIPEmbedder(AbstractEncoder):
+ """
+ Uses the OpenCLIP transformer encoder for text
+ """
+ LAYERS = [
+ #"pooled",
+ "last",
+ "penultimate"
+ ]
+ def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
+ freeze=True, layer="last"):
+ super().__init__()
+ assert layer in self.LAYERS
+ model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
+ del model.visual
+ self.model = model
+
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == "last":
+ self.layer_idx = 0
+ elif self.layer == "penultimate":
+ self.layer_idx = 1
+ else:
+ raise NotImplementedError()
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ tokens = open_clip.tokenize(text)
+ z = self.encode_with_transformer(tokens.to(self.device))
+ return z
+
+ def encode_with_transformer(self, text):
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
+ x = x + self.model.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.model.ln_final(x)
+ return x
+
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
+ for i, r in enumerate(self.model.transformer.resblocks):
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
+ break
+ if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
+ x = checkpoint(r, x, attn_mask)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ return x
+
+ def encode(self, text):
+ return self(text)
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
+ super().__init__()
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.transformer = CLIPTextModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens)
+
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
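+
+# Usage sketch (assumption): text conditioning for a prompt batch, e.g.
+#   embedder = FrozenCLIPEmbedder().to("cuda")
+#   c = embedder.encode(["a photo of a cat"])  # -> [1, 77, 768] for ViT-L/14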
+
+class FinetuningCLIPEmbedder(AbstractEncoder):
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
+ super().__init__()
+ setattr(transformers.models.clip.modeling_clip,"CLIPTextTransformer", CLIPTextTransformer_M)
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.transformer = CLIPTextModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ # self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ # batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ # return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ # tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(text)
+
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+class FrozenCLIPTextEmbedder(nn.Module):
+ """
+ Uses the CLIP transformer encoder for text.
+ """
+ def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
+ super().__init__()
+ self.model, _ = clip.load(version, jit=False, device="cpu")
+ self.device = device
+ self.max_length = max_length
+ self.n_repeat = n_repeat
+ self.normalize = normalize
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ tokens = clip.tokenize(text).to(self.device)
+ z = self.model.encode_text(tokens)
+ if self.normalize:
+ z = z / torch.linalg.norm(z, dim=1, keepdim=True)
+ return z
+
+ def encode(self, text):
+ z = self(text)
+ if z.ndim==2:
+ z = z[:, None, :]
+ z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
+ return z
+
+class FrozenClipImageEmbedder(nn.Module):
+ """
+ Uses the CLIP image encoder.
+ """
+ def __init__(
+ self,
+ model,
+ jit=False,
+ device='cuda' if torch.cuda.is_available() else 'cpu',
+ antialias=False,
+ ):
+ super().__init__()
+ self.model, _ = clip.load(name=model, device=device, jit=jit)
+
+ self.antialias = antialias
+
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+
+ def preprocess(self, x):
+ # normalize to [0,1]
+ x = kornia.geometry.resize(x, (224, 224),
+ interpolation='bicubic',align_corners=True,
+ antialias=self.antialias)
+ x = (x + 1.) / 2.
+ # renormalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def forward(self, x):
+ # x is assumed to be in range [-1,1]
+ return self.model.encode_image(self.preprocess(x))
+
+class FrozenClipImageEmbedderNew(nn.Module):
+ """
+ Uses the CLIP image encoder.
+ """
+ def __init__(
+ self,
+ model,
+ in_channels=1024,
+ output_channels=768,
+ jit=False,
+ device='cuda' if torch.cuda.is_available() else 'cpu',
+ antialias=False,
+ ):
+ super().__init__()
+ clip_model, _ = clip.load(name=model, device=device, jit=jit)
+ self.encoder = clip_model.visual
+ self.linear = nn.Linear(in_channels, output_channels)
+
+ self.antialias = antialias
+
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+
+ def preprocess(self, x):
+ # normalize to [0,1]
+ # x = kornia.geometry.resize(x, (224, 224),
+ # interpolation='bicubic',align_corners=True,
+ # antialias=self.antialias)
+ x = (x + 1.) / 2.
+ # renormalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def forward(self, x):
+ # x is assumed to be in range [-1,1]
+ x = self.encoder(self.preprocess(x)).float()
+ x = self.linear(x)
+ return x
+
+class ClipImageEmbedder(nn.Module):
+ """
+ Uses the CLIP image encoder.
+ """
+ def __init__(
+ self,
+ vision_layers=[2,2,2,2],
+ embed_dim=768,
+ vision_heads=64,
+ input_resolution=224,
+ vision_width=64,
+ jit=False,
+ device='cuda' if torch.cuda.is_available() else 'cpu',
+ antialias=False,
+ input_dim=3
+ ):
+ super().__init__()
+ from clip.model import ModifiedResNet
+ self.encoder = ModifiedResNet(
+ layers=vision_layers,
+ output_dim=embed_dim,
+ heads=vision_heads,
+ input_resolution=input_resolution,
+ width=vision_width,
+ input_dim=input_dim
+ )
+
+ # self.pixel_unshuffle = nn.PixelUnshuffle(2)
+
+ # self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+ # self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+
+ # def preprocess(self, x):
+ # # normalize to [0,1]
+ # x = (x + 1.) / 2.
+ # # renormalize according to clip
+ # x = kornia.enhance.normalize(x, self.mean, self.std)
+ #
+ # # return self.pixel_unshuffle(x)
+ # return x
+
+ def forward(self, x):
+ # x is assumed to be in range [-1,1]
+ x = self.encoder(x).float()
+ return x
+
+class ClipImageEmbedderOri(nn.Module):
+ """
+ Uses the CLIP image encoder.
+ """
+ def __init__(
+ self,
+ model,
+ in_channels,
+ out_channels,
+ jit=False,
+ device='cuda' if torch.cuda.is_available() else 'cpu',
+ antialias=False,
+ ):
+ super().__init__()
+ self.model, _ = clip.load(name=model, device=device, jit=jit)
+ self.freeze()
+
+ self.final_projector = nn.Linear(in_channels, out_channels)
+
+ self.antialias = antialias
+
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+
+ def preprocess(self, x):
+ # normalize to [0,1]
+ x = kornia.geometry.resize(x, (224, 224),
+ interpolation='bicubic',align_corners=True,
+ antialias=self.antialias)
+ x = (x + 1.) / 2.
+ # renormalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.model.parameters():
+ param.requires_grad = False
+
+ def forward(self, x):
+ # x is assumed to be in range [-1,1]
+ clip_fea = self.model.encode_image(self.preprocess(x)).float()
+ clip_fea = self.final_projector(clip_fea)
+ return clip_fea
+
+class ClipImage2TextEmbedder(nn.Module):
+ """
+ Uses the CLIP image encoder.
+ """
+ def __init__(
+ self,
+ model,
+ jit=False,
+ device='cuda' if torch.cuda.is_available() else 'cpu',
+ antialias=False,
+ ):
+ super().__init__()
+ self.model, _ = clip.load(name=model, device=device, jit=jit)
+
+ self.antialias = antialias
+
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+
+ def preprocess(self, x):
+ # normalize to [0,1]
+ x = (x + 1.) / 2.
+ # renormalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def forward(self, x):
+ # x is assumed to be in range [-1,1]
+ return self.model.encode_image(self.preprocess(x))
+
+
+if __name__ == "__main__":
+ from ldm.util import count_params
+ model = FrozenCLIPEmbedder()
+ count_params(model, verbose=True)
diff --git a/StableSR/ldm/modules/encoders/transformer_utils.py b/StableSR/ldm/modules/encoders/transformer_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3d90de216a12938c5f79336e8916d06f40988ef
--- /dev/null
+++ b/StableSR/ldm/modules/encoders/transformer_utils.py
@@ -0,0 +1,181 @@
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from transformers.models.clip.modeling_clip import CLIPTextTransformer, _expand_mask # _expand_mask is used below when an attention_mask is provided
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from transformers.models.clip.configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
+from typing import Any, Optional, Tuple, Union
+from transformers.utils import (
+ add_start_docstrings_to_model_forward,
+ replace_return_docstrings,
+)
+
+
+CLIP_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+ Parameters:
+ config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+CLIP_TEXT_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+ Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+ [What are position IDs?](../glossary#position-ids)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+CLIP_VISION_INPUTS_DOCSTRING = r"""
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
+ [`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+CLIP_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+ Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ [What are attention masks?](../glossary#attention-mask)
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+ [What are position IDs?](../glossary#position-ids)
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
+ [`CLIPFeatureExtractor`]. See [`CLIPFeatureExtractor.__call__`] for details.
+ return_loss (`bool`, *optional*):
+ Whether or not to return the contrastive loss.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+class CLIPTextTransformer_M(CLIPTextTransformer):
+
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+ """
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is None:
+ raise ValueError("You have to specify either input_ids")
+
+ input_shape = input_ids.size()
+ # input_ids = input_ids.view(-1, input_shape[-1])
+
+ # hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
+ hidden_states = input_ids
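+ # Note: unlike the stock CLIPTextTransformer, this variant expects
+ # `input_ids` to already be embedded text of shape [bsz, seq_len, dim]
+ # (the embedding call above is deliberately commented out).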
+
+ bsz, seq_len, _ = input_shape
+ # CLIP's text model uses causal mask, prepare it here.
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
+ causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
+ hidden_states.device
+ )
+ # expand attention_mask
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ attention_mask=attention_mask,
+ causal_attention_mask=causal_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
+
+ # text_embeds.shape = [batch_size, sequence_length, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+ pooled_output = last_hidden_state[
+ torch.arange(last_hidden_state.shape[0], device=input_ids.device), torch.mean(input_ids, -1).to(torch.int).argmax(dim=-1)
+ ]
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+ def _build_causal_attention_mask(self, bsz, seq_len, dtype):
+ # lazily create causal attention mask, with full attention between the vision tokens
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
+ mask.fill_(torch.tensor(torch.finfo(dtype).min))
+ mask.triu_(1) # zero out the lower diagonal
+ mask = mask.unsqueeze(1) # expand mask
+ return mask
diff --git a/StableSR/ldm/modules/image_degradation/__init__.py b/StableSR/ldm/modules/image_degradation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7836cada81f90ded99c58d5942eea4c3477f58fc
--- /dev/null
+++ b/StableSR/ldm/modules/image_degradation/__init__.py
@@ -0,0 +1,2 @@
+from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
+from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
diff --git a/StableSR/ldm/modules/image_degradation/bsrgan.py b/StableSR/ldm/modules/image_degradation/bsrgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..32ef56169978e550090261cddbcf5eb611a6173b
--- /dev/null
+++ b/StableSR/ldm/modules/image_degradation/bsrgan.py
@@ -0,0 +1,730 @@
+# -*- coding: utf-8 -*-
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+import numpy as np
+import cv2
+import torch
+
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+
+import ldm.modules.image_degradation.utils_image as util
+
+
+def modcrop_np(img, sf):
+ '''
+ Args:
+ img: numpy image, WxH or WxHxC
+ sf: scale factor
+ Return:
+ cropped image
+ '''
+ w, h = img.shape[:2]
+ im = np.copy(img)
+ return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+ k_size = k.shape[0]
+ # Calculate the big kernels size
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+ # Loop over the small kernel to fill the big one
+ for r in range(k_size):
+ for c in range(k_size):
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+ # Crop the edges of the big kernel to ignore very small values and increase run time of SR
+ crop = k_size // 2
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
+ # Normalize to 1
+ return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+ """ generate an anisotropic Gaussian kernel
+ Args:
+ ksize : e.g., 15, kernel size
+ theta : [0, pi], rotation angle range
+ l1 : [0.1,50], scaling of eigenvalues
+ l2 : [0.1,l1], scaling of eigenvalues
+ If l1 = l2, will get an isotropic Gaussian kernel.
+ Returns:
+ k : kernel
+ """
+
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+ D = np.array([[l1, 0], [0, l2]])
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+ return k
+
+
+def gm_blur_kernel(mean, cov, size=15):
+ center = size / 2.0 + 0.5
+ k = np.zeros([size, size])
+ for y in range(size):
+ for x in range(size):
+ cy = y - center + 1
+ cx = x - center + 1
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+ k = k / np.sum(k)
+ return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+ """shift pixel for super-resolution with different scale factors
+ Args:
+ x: WxHxC or WxH
+ sf: scale factor
+ upper_left: shift direction
+ """
+ h, w = x.shape[:2]
+ shift = (sf - 1) * 0.5
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+ if upper_left:
+ x1 = xv + shift
+ y1 = yv + shift
+ else:
+ x1 = xv - shift
+ y1 = yv - shift
+
+ x1 = np.clip(x1, 0, w - 1)
+ y1 = np.clip(y1, 0, h - 1)
+
+ if x.ndim == 2:
+ x = interp2d(xv, yv, x)(x1, y1)
+ if x.ndim == 3:
+ for i in range(x.shape[-1]):
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+ return x
+
+
+def blur(x, k):
+ '''
+ x: image, NxcxHxW
+ k: kernel, Nx1xhxw
+ '''
+ n, c = x.shape[:2]
+ p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+ k = k.repeat(1, c, 1, 1)
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
+ x = x.view(1, -1, x.shape[2], x.shape[3])
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+ x = x.view(n, c, x.shape[2], x.shape[3])
+
+ return x
+
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+ """"
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+ # Kai Zhang
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
+ # max_var = 2.5 * sf
+ """
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+ theta = np.random.rand() * np.pi # random theta
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+ # Set COV matrix using Lambdas and Theta
+ LAMBDA = np.diag([lambda_1, lambda_2])
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ SIGMA = Q @ LAMBDA @ Q.T
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+ # Set expectation position (shifting kernel for aligned image)
+ MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
+ MU = MU[None, None, :, None]
+
+ # Create meshgrid for Gaussian
+ [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+ Z = np.stack([X, Y], 2)[:, :, :, None]
+
+ # Calculate Gaussian for every pixel of the kernel
+ ZZ = Z - MU
+ ZZ_t = ZZ.transpose(0, 1, 3, 2)
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+ # shift the kernel so it will be centered
+ # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+ # Normalize the kernel and return
+ # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+ kernel = raw_kernel / np.sum(raw_kernel)
+ return kernel
+
+
+def fspecial_gaussian(hsize, sigma):
+ hsize = [hsize, hsize]
+ siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+ std = sigma
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+ arg = -(x * x + y * y) / (2 * std * std)
+ h = np.exp(arg)
+ h[h < scipy.finfo(float).eps * h.max()] = 0
+ sumh = h.sum()
+ if sumh != 0:
+ h = h / sumh
+ return h
+
+
+def fspecial_laplacian(alpha):
+ alpha = max([0, min([alpha, 1])])
+ h1 = alpha / (alpha + 1)
+ h2 = (1 - alpha) / (alpha + 1)
+ h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+ h = np.array(h)
+ return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+ '''
+ python code from:
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+ '''
+ if filter_type == 'gaussian':
+ return fspecial_gaussian(*args, **kwargs)
+ if filter_type == 'laplacian':
+ return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+ '''
+ Args:
+ x: HxWxC image, [0, 1]
+ sf: down-scale factor
+ Return:
+ bicubicly downsampled LR image
+ '''
+ x = util.imresize_np(x, scale=1 / sf)
+ return x
+
+
+def srmd_degradation(x, k, sf=3):
+ ''' blur + bicubic downsampling
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2018learning,
+ title={Learning a single convolutional super-resolution network for multiple degradations},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={3262--3271},
+ year={2018}
+ }
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
+ x = bicubic_degradation(x, sf=sf)
+ return x
+
+
+def dpsr_degradation(x, k, sf=3):
+ ''' bicubic downsampling + blur
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2019deep,
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={1671--1681},
+ year={2019}
+ }
+ '''
+ x = bicubic_degradation(x, sf=sf)
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ return x
+
+
+def classical_degradation(x, k, sf=3):
+ ''' blur + downsampling
+ Args:
+ x: HxWxC image, [0, 1]/[0, 255]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+ st = 0
+ return x[st::sf, st::sf, ...]
+
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+ """USM sharpening. borrowed from real-ESRGAN
+ Input image: I; Blurry image: B.
+ 1. K = I + weight * (I - B)
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
+ 3. Blur mask:
+ 4. Out = Mask * K + (1 - Mask) * I
+ Args:
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+ weight (float): Sharp weight. Default: 0.5.
+ radius (float): Kernel size of Gaussian blur. Default: 50.
+ threshold (int): Threshold on abs(I - B) * 255 for the sharpening mask. Default: 10.
+ """
+ if radius % 2 == 0:
+ radius += 1
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
+ residual = img - blur
+ mask = np.abs(residual) * 255 > threshold
+ mask = mask.astype('float32')
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+ K = img + weight * residual
+ K = np.clip(K, 0, 1)
+ return soft_mask * K + (1 - soft_mask) * img
+
+
+def add_blur(img, sf=4):
+ wd2 = 4.0 + sf
+ wd = 2.0 + 0.2 * sf
+ if random.random() < 0.5:
+ l1 = wd2 * random.random()
+ l2 = wd2 * random.random()
+ k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+ else:
+ k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
+ img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+ return img
+
+
+def add_resize(img, sf=4):
+ rnum = np.random.rand()
+ if rnum > 0.8: # up
+ sf1 = random.uniform(1, 2)
+ elif rnum < 0.7: # down
+ sf1 = random.uniform(0.5 / sf, 1)
+ else:
+ sf1 = 1.0
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+# noise_level = random.randint(noise_level1, noise_level2)
+# rnum = np.random.rand()
+# if rnum > 0.6: # add color Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+# elif rnum < 0.4: # add grayscale Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+# else: # add noise
+# L = noise_level2 / 255.
+# D = np.diag(np.random.rand(3))
+# U = orth(np.random.rand(3, 3))
+# conv = np.dot(np.dot(np.transpose(U), D), U)
+# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+# img = np.clip(img, 0.0, 1.0)
+# return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ rnum = np.random.rand()
+ if rnum > 0.6: # add color Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4: # add grayscale Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else: # add noise
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ img = np.clip(img, 0.0, 1.0)
+ rnum = random.random()
+ if rnum > 0.6:
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4:
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else:
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_Poisson_noise(img):
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
+ if random.random() < 0.5:
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
+ else:
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+ img += noise_gray[:, :, np.newaxis]
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_JPEG_noise(img):
+ quality_factor = random.randint(30, 95)
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+ img = cv2.imdecode(encimg, 1)
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+ return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+ h, w = lq.shape[:2]
+ rnd_h = random.randint(0, h - lq_patchsize)
+ rnd_w = random.randint(0, w - lq_patchsize)
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+ return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ hq = img.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ img = util.imresize_np(img, 1 / 2, True)
+ img = np.clip(img, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ img = add_blur(img, sf=sf)
+
+ elif i == 1:
+ img = add_blur(img, sf=sf)
+
+ elif i == 2:
+ a, b = img.shape[1], img.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ img = add_JPEG_noise(img)
+
+ elif i == 6:
+ # add processed camera sensor noise
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+ return img, hq
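+
+# Usage sketch (assumption): `img` is an HxWxC float image in [0, 1] whose
+# sides are at least lq_patchsize * sf, e.g.
+#   lq, hq = degradation_bsrgan(img, sf=4, lq_patchsize=72)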
+
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ example: a dict with key "image" holding the degraded low-quality image (uint8, HxWxC)
+ """
+ image = util.uint2single(image)
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = image.shape[:2]
+ image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = image.shape[:2]
+
+ hq = image.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ image = util.imresize_np(image, 1 / 2, True)
+ image = np.clip(image, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ image = add_blur(image, sf=sf)
+
+ elif i == 1:
+ image = add_blur(image, sf=sf)
+
+ elif i == 2:
+ a, b = image.shape[1], image.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ image = add_JPEG_noise(image)
+
+ # elif i == 6:
+ # # add processed camera sensor noise
+ # if random.random() < isp_prob and isp_model is not None:
+ # with torch.no_grad():
+ # img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ image = add_JPEG_noise(image)
+ image = util.single2uint(image)
+ example = {"image":image}
+ return example
+
+
+# TODO: in case there is a pickle error, one needs to replace a += x with a = a + x in add_speckle_noise etc.
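+# For example (illustrative only), an in-place update such as
+#     img += img * noise
+# would be rewritten as
+#     img = img + img * noise
+# which avoids writing into a buffer that may be read-only (one common source of such errors).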
+def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
+    """
+    This is an extended degradation model combining
+    the degradation models of BSRGAN and Real-ESRGAN
+    ----------
+    img: HxWxC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
+    sf: scale factor
+    shuffle_prob: probability of shuffling the degradation order
+    use_sharp: whether to sharpen the HQ image first
+    lq_patchsize: size of the low-quality patch
+    isp_model: camera ISP model
+    Returns
+    -------
+    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+    """
+
+ h1, w1 = img.shape[:2]
+    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop (h1, w1 are height, width)
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ if use_sharp:
+ img = add_sharpening(img)
+ hq = img.copy()
+
+ if random.random() < shuffle_prob:
+ shuffle_order = random.sample(range(13), 13)
+ else:
+ shuffle_order = list(range(13))
+ # local shuffle for noise, JPEG is always the last one
+ shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
+ shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
+
+ poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
+
+ for i in shuffle_order:
+ if i == 0:
+ img = add_blur(img, sf=sf)
+ elif i == 1:
+ img = add_resize(img, sf=sf)
+ elif i == 2:
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+ elif i == 3:
+ if random.random() < poisson_prob:
+ img = add_Poisson_noise(img)
+ elif i == 4:
+ if random.random() < speckle_prob:
+ img = add_speckle_noise(img)
+ elif i == 5:
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ elif i == 6:
+ img = add_JPEG_noise(img)
+ elif i == 7:
+ img = add_blur(img, sf=sf)
+ elif i == 8:
+ img = add_resize(img, sf=sf)
+ elif i == 9:
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+ elif i == 10:
+ if random.random() < poisson_prob:
+ img = add_Poisson_noise(img)
+ elif i == 11:
+ if random.random() < speckle_prob:
+ img = add_speckle_noise(img)
+ elif i == 12:
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ else:
+ print('check the shuffle!')
+
+ # resize to desired size
+ img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf, lq_patchsize)
+
+ return img, hq
+
+
+if __name__ == '__main__':
+ print("hey")
+ img = util.imread_uint('utils/test.png', 3)
+ print(img)
+    img = img[:448, :448]
+    h = img.shape[0] // 4
+    print("resizing to", h)
+    sf = 4
+    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+    for i in range(20):
+        print(i)
+        img_hq = img
+        img_lq = deg_fn(img)["image"]  # degradation_bsrgan_variant returns a dict
+        img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
+        print(img_lq)
+        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
+ print(img_lq.shape)
+ print("bicubic", img_lq_bicubic.shape)
+ print(img_hq.shape)
+ lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+ util.imsave(img_concat, str(i) + '.png')
+
+
diff --git a/StableSR/ldm/modules/image_degradation/bsrgan_light.py b/StableSR/ldm/modules/image_degradation/bsrgan_light.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e1f823996bf559e9b015ea9aa2b3cd38dd13af1
--- /dev/null
+++ b/StableSR/ldm/modules/image_degradation/bsrgan_light.py
@@ -0,0 +1,650 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+import cv2
+import torch
+
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+
+import ldm.modules.image_degradation.utils_image as util
+
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+
+def modcrop_np(img, sf):
+ '''
+ Args:
+ img: numpy image, WxH or WxHxC
+ sf: scale factor
+ Return:
+ cropped image
+ '''
+ w, h = img.shape[:2]
+ im = np.copy(img)
+ return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+ k_size = k.shape[0]
+ # Calculate the big kernels size
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+ # Loop over the small kernel to fill the big one
+ for r in range(k_size):
+ for c in range(k_size):
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+    # Crop the edges of the big kernel to ignore very small values and reduce the run time of SR
+ crop = k_size // 2
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
+ # Normalize to 1
+ return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+ """ generate an anisotropic Gaussian kernel
+ Args:
+ ksize : e.g., 15, kernel size
+ theta : [0, pi], rotation angle range
+ l1 : [0.1,50], scaling of eigenvalues
+ l2 : [0.1,l1], scaling of eigenvalues
+ If l1 = l2, will get an isotropic Gaussian kernel.
+ Returns:
+ k : kernel
+ """
+
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+ D = np.array([[l1, 0], [0, l2]])
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+ return k
+
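+# Illustrative usage sketch (not called with these values elsewhere in this file);
+# the parameter ranges follow the docstring above.
+#
+#   theta = random.random() * np.pi
+#   l1 = random.uniform(0.1, 50)
+#   l2 = random.uniform(0.1, l1)        # l1 == l2 gives an isotropic kernel
+#   k = anisotropic_Gaussian(ksize=15, theta=theta, l1=l1, l2=l2)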
+
+def gm_blur_kernel(mean, cov, size=15):
+ center = size / 2.0 + 0.5
+ k = np.zeros([size, size])
+ for y in range(size):
+ for x in range(size):
+ cy = y - center + 1
+ cx = x - center + 1
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+ k = k / np.sum(k)
+ return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+ """shift pixel for super-resolution with different scale factors
+ Args:
+ x: WxHxC or WxH
+ sf: scale factor
+ upper_left: shift direction
+ """
+ h, w = x.shape[:2]
+ shift = (sf - 1) * 0.5
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+ if upper_left:
+ x1 = xv + shift
+ y1 = yv + shift
+ else:
+ x1 = xv - shift
+ y1 = yv - shift
+
+ x1 = np.clip(x1, 0, w - 1)
+ y1 = np.clip(y1, 0, h - 1)
+
+ if x.ndim == 2:
+ x = interp2d(xv, yv, x)(x1, y1)
+ if x.ndim == 3:
+ for i in range(x.shape[-1]):
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+ return x
+
+
+def blur(x, k):
+ '''
+ x: image, NxcxHxW
+ k: kernel, Nx1xhxw
+ '''
+ n, c = x.shape[:2]
+ p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+ k = k.repeat(1, c, 1, 1)
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
+ x = x.view(1, -1, x.shape[2], x.shape[3])
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+ x = x.view(n, c, x.shape[2], x.shape[3])
+
+ return x
+
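+# Illustrative usage sketch (an assumption, not called elsewhere in this file):
+# blur() expects a batched image tensor and one kernel per batch element.
+#
+#   x = torch.rand(2, 3, 64, 64)                      # N x C x H x W
+#   k = torch.rand(2, 1, 7, 7)
+#   k = k / k.sum(dim=(-2, -1), keepdim=True)         # normalize each kernel
+#   y = blur(x, k)                                    # same spatial size (odd kernel, replicate padding)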
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+    """
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+ # Kai Zhang
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
+ # max_var = 2.5 * sf
+ """
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+ theta = np.random.rand() * np.pi # random theta
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+ # Set COV matrix using Lambdas and Theta
+ LAMBDA = np.diag([lambda_1, lambda_2])
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ SIGMA = Q @ LAMBDA @ Q.T
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+ # Set expectation position (shifting kernel for aligned image)
+ MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
+ MU = MU[None, None, :, None]
+
+ # Create meshgrid for Gaussian
+ [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+ Z = np.stack([X, Y], 2)[:, :, :, None]
+
+    # Calculate Gaussian for every pixel of the kernel
+ ZZ = Z - MU
+ ZZ_t = ZZ.transpose(0, 1, 3, 2)
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+ # shift the kernel so it will be centered
+ # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+ # Normalize the kernel and return
+ # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+ kernel = raw_kernel / np.sum(raw_kernel)
+ return kernel
+
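+# Illustrative call (not used elsewhere in this file); the variance bounds follow
+# the min_var/max_var comment in the docstring and are assumptions here.
+#
+#   sf = 4
+#   k = gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([sf, sf]),
+#                  min_var=0.175 * sf, max_var=2.5 * sf)
+#   # k sums to 1 and is shifted so the downsampled image stays aligned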
+
+def fspecial_gaussian(hsize, sigma):
+ hsize = [hsize, hsize]
+ siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+ std = sigma
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+ arg = -(x * x + y * y) / (2 * std * std)
+ h = np.exp(arg)
+    h[h < np.finfo(float).eps * h.max()] = 0
+ sumh = h.sum()
+ if sumh != 0:
+ h = h / sumh
+ return h
+
+
+def fspecial_laplacian(alpha):
+ alpha = max([0, min([alpha, 1])])
+ h1 = alpha / (alpha + 1)
+ h2 = (1 - alpha) / (alpha + 1)
+ h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+ h = np.array(h)
+ return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+ '''
+ python code from:
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+ '''
+ if filter_type == 'gaussian':
+ return fspecial_gaussian(*args, **kwargs)
+ if filter_type == 'laplacian':
+ return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+ '''
+ Args:
+ x: HxWxC image, [0, 1]
+ sf: down-scale factor
+ Return:
+ bicubicly downsampled LR image
+ '''
+ x = util.imresize_np(x, scale=1 / sf)
+ return x
+
+
+def srmd_degradation(x, k, sf=3):
+ ''' blur + bicubic downsampling
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2018learning,
+ title={Learning a single convolutional super-resolution network for multiple degradations},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={3262--3271},
+ year={2018}
+ }
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
+ x = bicubic_degradation(x, sf=sf)
+ return x
+
+
+def dpsr_degradation(x, k, sf=3):
+ ''' bicubic downsampling + blur
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2019deep,
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={1671--1681},
+ year={2019}
+ }
+ '''
+ x = bicubic_degradation(x, sf=sf)
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ return x
+
+
+def classical_degradation(x, k, sf=3):
+ ''' blur + downsampling
+ Args:
+ x: HxWxC image, [0, 1]/[0, 255]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+ st = 0
+ return x[st::sf, st::sf, ...]
+
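+# Illustrative comparison of the three classical pipelines above (a sketch, not
+# used elsewhere in this file); x is an HxWxC float image in [0, 1] and k a
+# normalized blur kernel, e.g. fspecial('gaussian', 15, 1.6).
+#
+#   lr_srmd = srmd_degradation(x, k, sf=4)            # blur, then bicubic downsample
+#   lr_dpsr = dpsr_degradation(x, k, sf=4)            # bicubic downsample, then blur
+#   lr_classical = classical_degradation(x, k, sf=4)  # blur, then stride-sf subsampling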
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+    """USM sharpening, borrowed from Real-ESRGAN.
+    Input image: I; Blurry image: B.
+    1. K = I + weight * (I - B)
+    2. Mask = 1 if abs(I - B) * 255 > threshold, else 0
+    3. Blur the mask to obtain a soft mask
+    4. Out = soft_mask * K + (1 - soft_mask) * I
+    Args:
+        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+        weight (float): Sharpening weight. Default: 0.5.
+        radius (float): Kernel size of the Gaussian blur. Default: 50.
+        threshold (int): Threshold on abs(I - B) in the [0, 255] range for the sharpening mask. Default: 10.
+    """
+ if radius % 2 == 0:
+ radius += 1
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
+ residual = img - blur
+ mask = np.abs(residual) * 255 > threshold
+ mask = mask.astype('float32')
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+ K = img + weight * residual
+ K = np.clip(K, 0, 1)
+ return soft_mask * K + (1 - soft_mask) * img
+
+
+def add_blur(img, sf=4):
+ wd2 = 4.0 + sf
+ wd = 2.0 + 0.2 * sf
+
+ wd2 = wd2/4
+ wd = wd/4
+
+ if random.random() < 0.5:
+ l1 = wd2 * random.random()
+ l2 = wd2 * random.random()
+ k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+ else:
+ k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
+ img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+ return img
+
+
+def add_resize(img, sf=4):
+ rnum = np.random.rand()
+ if rnum > 0.8: # up
+ sf1 = random.uniform(1, 2)
+ elif rnum < 0.7: # down
+ sf1 = random.uniform(0.5 / sf, 1)
+ else:
+ sf1 = 1.0
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+# noise_level = random.randint(noise_level1, noise_level2)
+# rnum = np.random.rand()
+# if rnum > 0.6: # add color Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+# elif rnum < 0.4: # add grayscale Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+# else: # add noise
+# L = noise_level2 / 255.
+# D = np.diag(np.random.rand(3))
+# U = orth(np.random.rand(3, 3))
+# conv = np.dot(np.dot(np.transpose(U), D), U)
+# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+# img = np.clip(img, 0.0, 1.0)
+# return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ rnum = np.random.rand()
+ if rnum > 0.6: # add color Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4: # add grayscale Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else: # add noise
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ img = np.clip(img, 0.0, 1.0)
+ rnum = random.random()
+ if rnum > 0.6:
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4:
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else:
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_Poisson_noise(img):
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
+ if random.random() < 0.5:
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
+ else:
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+ img += noise_gray[:, :, np.newaxis]
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_JPEG_noise(img):
+ quality_factor = random.randint(80, 95)
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+ img = cv2.imdecode(encimg, 1)
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+ return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+ h, w = lq.shape[:2]
+ rnd_h = random.randint(0, h - lq_patchsize)
+ rnd_w = random.randint(0, w - lq_patchsize)
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+ return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = img.shape[:2]
+    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop (h1, w1 are height, width)
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ hq = img.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ img = util.imresize_np(img, 1 / 2, True)
+ img = np.clip(img, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+    if idx1 > idx2:  # ensure downsample2 comes before downsample3
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ img = add_blur(img, sf=sf)
+
+ elif i == 1:
+ img = add_blur(img, sf=sf)
+
+ elif i == 2:
+ a, b = img.shape[1], img.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ img = add_JPEG_noise(img)
+
+ elif i == 6:
+ # add processed camera sensor noise
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+ return img, hq
+
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None):
+    """
+    This is the degradation model of BSRGAN from the paper
+    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+    ----------
+    image: HxWxC, uint8, range [0, 255]
+    sf: scale factor
+    isp_model: camera ISP model (unused here; the ISP branch is commented out)
+    Returns
+    -------
+    example: dict with key "image" holding the degraded uint8 low-quality image
+    (unlike degradation_bsrgan, no high-quality patch is returned)
+    """
+ image = util.uint2single(image)
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = image.shape[:2]
+    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop (h1, w1 are height, width)
+ h, w = image.shape[:2]
+
+ hq = image.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ image = util.imresize_np(image, 1 / 2, True)
+ image = np.clip(image, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+    if idx1 > idx2:  # ensure downsample2 comes before downsample3
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ image = add_blur(image, sf=sf)
+
+ # elif i == 1:
+ # image = add_blur(image, sf=sf)
+
+        elif i == 2:
+ a, b = image.shape[1], image.shape[0]
+ # downsample2
+ if random.random() < 0.8:
+ sf1 = random.uniform(1, 2 * sf)
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
+
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ image = add_JPEG_noise(image)
+ #
+ # elif i == 6:
+ # # add processed camera sensor noise
+ # if random.random() < isp_prob and isp_model is not None:
+ # with torch.no_grad():
+ # img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ image = add_JPEG_noise(image)
+ image = util.single2uint(image)
+ example = {"image": image}
+ return example
+
+
+
+
+if __name__ == '__main__':
+ print("hey")
+ img = util.imread_uint('utils/test.png', 3)
+ img = img[:448, :448]
+ h = img.shape[0] // 4
+ print("resizing to", h)
+ sf = 4
+ deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+ for i in range(20):
+ print(i)
+ img_hq = img
+ img_lq = deg_fn(img)["image"]
+ img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
+ print(img_lq)
+ img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
+ print(img_lq.shape)
+ print("bicubic", img_lq_bicubic.shape)
+ print(img_hq.shape)
+ lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
+ (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+ util.imsave(img_concat, str(i) + '.png')
diff --git a/StableSR/ldm/modules/image_degradation/utils_image.py b/StableSR/ldm/modules/image_degradation/utils_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..0175f155ad900ae33c3c46ed87f49b352e3faf98
--- /dev/null
+++ b/StableSR/ldm/modules/image_degradation/utils_image.py
@@ -0,0 +1,916 @@
+import os
+import math
+import random
+import numpy as np
+import torch
+import cv2
+from torchvision.utils import make_grid
+from datetime import datetime
+#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
+
+
+os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
+
+
+'''
+# --------------------------------------------
+# Kai Zhang (github: https://github.com/cszn)
+# 03/Mar/2019
+# --------------------------------------------
+# https://github.com/twhui/SRGAN-pyTorch
+# https://github.com/xinntao/BasicSR
+# --------------------------------------------
+'''
+
+
+IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def get_timestamp():
+ return datetime.now().strftime('%y%m%d-%H%M%S')
+
+
+def imshow(x, title=None, cbar=False, figsize=None):
+ plt.figure(figsize=figsize)
+ plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
+ if title:
+ plt.title(title)
+ if cbar:
+ plt.colorbar()
+ plt.show()
+
+
+def surf(Z, cmap='rainbow', figsize=None):
+ plt.figure(figsize=figsize)
+ ax3 = plt.axes(projection='3d')
+
+ w, h = Z.shape[:2]
+ xx = np.arange(0,w,1)
+ yy = np.arange(0,h,1)
+ X, Y = np.meshgrid(xx, yy)
+ ax3.plot_surface(X,Y,Z,cmap=cmap)
+ #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
+ plt.show()
+
+
+'''
+# --------------------------------------------
+# get image paths
+# --------------------------------------------
+'''
+
+
+def get_image_paths(dataroot):
+ paths = None # return None if dataroot is None
+ if dataroot is not None:
+ paths = sorted(_get_paths_from_images(dataroot))
+ return paths
+
+
+def _get_paths_from_images(path):
+ assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
+ images = []
+ for dirpath, _, fnames in sorted(os.walk(path)):
+ for fname in sorted(fnames):
+ if is_image_file(fname):
+ img_path = os.path.join(dirpath, fname)
+ images.append(img_path)
+ assert images, '{:s} has no valid image file'.format(path)
+ return images
+
+
+'''
+# --------------------------------------------
+# split large images into small images
+# --------------------------------------------
+'''
+
+
+def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
+ w, h = img.shape[:2]
+ patches = []
+ if w > p_max and h > p_max:
+        w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=int))
+        h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=int))
+ w1.append(w-p_size)
+ h1.append(h-p_size)
+# print(w1)
+# print(h1)
+ for i in w1:
+ for j in h1:
+ patches.append(img[i:i+p_size, j:j+p_size,:])
+ else:
+ patches.append(img)
+
+ return patches
+
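+# Worked example (illustrative): for a 1600x1600 image with p_size=512,
+# p_overlap=64 and p_max=800, the start coordinates along each axis are
+# np.arange(0, 1600 - 512, 512 - 64) = [0, 448, 896] plus the appended
+# 1600 - 512 = 1088, i.e. 4 x 4 = 16 overlapping 512x512 patches.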
+
+def imssave(imgs, img_path):
+ """
+ imgs: list, N images of size WxHxC
+ """
+ img_name, ext = os.path.splitext(os.path.basename(img_path))
+
+ for i, img in enumerate(imgs):
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')
+ cv2.imwrite(new_path, img)
+
+
+def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
+    """
+    Split the large images from original_dataroot into small overlapping images of size (p_size)x(p_size),
+    and save them into taget_dataroot; only images larger than (p_max)x(p_max) are split.
+    Args:
+        original_dataroot: folder containing the source images
+        taget_dataroot: folder to save the patches into
+        p_size: size of the small images
+        p_overlap: overlap between adjacent patches; the training patch size is a good choice
+        p_max: images smaller than (p_max)x(p_max) are kept unchanged.
+    """
+ paths = get_image_paths(original_dataroot)
+ for img_path in paths:
+ # img_name, ext = os.path.splitext(os.path.basename(img_path))
+ img = imread_uint(img_path, n_channels=n_channels)
+ patches = patches_from_image(img, p_size, p_overlap, p_max)
+ imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path)))
+ #if original_dataroot == taget_dataroot:
+ #del img_path
+
+'''
+# --------------------------------------------
+# makedir
+# --------------------------------------------
+'''
+
+
+def mkdir(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+
+def mkdirs(paths):
+ if isinstance(paths, str):
+ mkdir(paths)
+ else:
+ for path in paths:
+ mkdir(path)
+
+
+def mkdir_and_rename(path):
+ if os.path.exists(path):
+ new_name = path + '_archived_' + get_timestamp()
+ print('Path already exists. Rename it to [{:s}]'.format(new_name))
+ os.rename(path, new_name)
+ os.makedirs(path)
+
+
+'''
+# --------------------------------------------
+# read image from path
+# opencv is fast, but read BGR numpy image
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# get uint8 image of size HxWxn_channels (RGB)
+# --------------------------------------------
+def imread_uint(path, n_channels=3):
+ # input: path
+ # output: HxWx3(RGB or GGG), or HxWx1 (G)
+ if n_channels == 1:
+ img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
+ img = np.expand_dims(img, axis=2) # HxWx1
+ elif n_channels == 3:
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
+ if img.ndim == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
+ else:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
+ return img
+
+
+# --------------------------------------------
+# matlab's imwrite
+# --------------------------------------------
+def imsave(img, img_path):
+ img = np.squeeze(img)
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ cv2.imwrite(img_path, img)
+
+def imwrite(img, img_path):
+ img = np.squeeze(img)
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ cv2.imwrite(img_path, img)
+
+
+
+# --------------------------------------------
+# get single image of size HxWxn_channels (BGR)
+# --------------------------------------------
+def read_img(path):
+ # read image by cv2
+ # return: Numpy float32, HWC, BGR, [0,1]
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
+ img = img.astype(np.float32) / 255.
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ # some images have 4 channels
+ if img.shape[2] > 3:
+ img = img[:, :, :3]
+ return img
+
+
+'''
+# --------------------------------------------
+# image format conversion
+# --------------------------------------------
+# numpy(single) <---> numpy(uint)
+# numpy(single) <---> tensor
+# numpy(uint) <---> tensor
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# numpy(single) [0, 1] <---> numpy(uint)
+# --------------------------------------------
+
+
+def uint2single(img):
+
+ return np.float32(img/255.)
+
+
+def single2uint(img):
+
+ return np.uint8((img.clip(0, 1)*255.).round())
+
+
+def uint162single(img):
+
+ return np.float32(img/65535.)
+
+
+def single2uint16(img):
+
+ return np.uint16((img.clip(0, 1)*65535.).round())
+
+
+# --------------------------------------------
+# numpy(uint) (HxWxC or HxW) <---> tensor
+# --------------------------------------------
+
+
+# convert uint to 4-dimensional torch tensor
+def uint2tensor4(img):
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
+
+
+# convert uint to 3-dimensional torch tensor
+def uint2tensor3(img):
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
+
+
+# convert 2/3/4-dimensional torch tensor to uint
+def tensor2uint(img):
+ img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+ return np.uint8((img*255.0).round())
+
+
+# --------------------------------------------
+# numpy(single) (HxWxC) <---> tensor
+# --------------------------------------------
+
+
+# convert single (HxWxC) to 3-dimensional torch tensor
+def single2tensor3(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
+
+
+# convert single (HxWxC) to 4-dimensional torch tensor
+def single2tensor4(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
+
+
+# convert torch tensor to single
+def tensor2single(img):
+ img = img.data.squeeze().float().cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+
+ return img
+
+# convert torch tensor to single
+def tensor2single3(img):
+ img = img.data.squeeze().float().cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+ elif img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return img
+
+
+def single2tensor5(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
+
+
+def single32tensor5(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
+
+
+def single42tensor4(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
+
+
+# from skimage.io import imread, imsave
+def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
+ '''
+ Converts a torch Tensor into an image Numpy array of BGR channel order
+ Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
+ Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
+ '''
+ tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
+ tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
+ n_dim = tensor.dim()
+ if n_dim == 4:
+ n_img = len(tensor)
+ img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
+ elif n_dim == 3:
+ img_np = tensor.numpy()
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
+ elif n_dim == 2:
+ img_np = tensor.numpy()
+ else:
+ raise TypeError(
+ 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
+ if out_type == np.uint8:
+ img_np = (img_np * 255.0).round()
+        # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
+ return img_np.astype(out_type)
+
+
+'''
+# --------------------------------------------
+# Augmentation, flip and/or rotate
+# --------------------------------------------
+# The following two are enough.
+# (1) augment_img: numpy image of WxHxC or WxH
+# (2) augment_img_tensor4: tensor image 1xCxWxH
+# --------------------------------------------
+'''
+
+
+def augment_img(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return np.flipud(np.rot90(img))
+ elif mode == 2:
+ return np.flipud(img)
+ elif mode == 3:
+ return np.rot90(img, k=3)
+ elif mode == 4:
+ return np.flipud(np.rot90(img, k=2))
+ elif mode == 5:
+ return np.rot90(img)
+ elif mode == 6:
+ return np.rot90(img, k=2)
+ elif mode == 7:
+ return np.flipud(np.rot90(img, k=3))
+
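+# Illustrative usage (an assumption): pick one of the 8 flip/rotation combinations
+# (the dihedral group of the square) at random for data augmentation.
+#
+#   img_aug = augment_img(img, mode=random.randint(0, 7))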
+
+def augment_img_tensor4(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return img.rot90(1, [2, 3]).flip([2])
+ elif mode == 2:
+ return img.flip([2])
+ elif mode == 3:
+ return img.rot90(3, [2, 3])
+ elif mode == 4:
+ return img.rot90(2, [2, 3]).flip([2])
+ elif mode == 5:
+ return img.rot90(1, [2, 3])
+ elif mode == 6:
+ return img.rot90(2, [2, 3])
+ elif mode == 7:
+ return img.rot90(3, [2, 3]).flip([2])
+
+
+def augment_img_tensor(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ img_size = img.size()
+ img_np = img.data.cpu().numpy()
+ if len(img_size) == 3:
+ img_np = np.transpose(img_np, (1, 2, 0))
+ elif len(img_size) == 4:
+ img_np = np.transpose(img_np, (2, 3, 1, 0))
+ img_np = augment_img(img_np, mode=mode)
+ img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
+ if len(img_size) == 3:
+ img_tensor = img_tensor.permute(2, 0, 1)
+ elif len(img_size) == 4:
+ img_tensor = img_tensor.permute(3, 2, 0, 1)
+
+ return img_tensor.type_as(img)
+
+
+def augment_img_np3(img, mode=0):
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return img.transpose(1, 0, 2)
+ elif mode == 2:
+ return img[::-1, :, :]
+ elif mode == 3:
+ img = img[::-1, :, :]
+ img = img.transpose(1, 0, 2)
+ return img
+ elif mode == 4:
+ return img[:, ::-1, :]
+ elif mode == 5:
+ img = img[:, ::-1, :]
+ img = img.transpose(1, 0, 2)
+ return img
+ elif mode == 6:
+ img = img[:, ::-1, :]
+ img = img[::-1, :, :]
+ return img
+ elif mode == 7:
+ img = img[:, ::-1, :]
+ img = img[::-1, :, :]
+ img = img.transpose(1, 0, 2)
+ return img
+
+
+def augment_imgs(img_list, hflip=True, rot=True):
+ # horizontal flip OR rotate
+ hflip = hflip and random.random() < 0.5
+ vflip = rot and random.random() < 0.5
+ rot90 = rot and random.random() < 0.5
+
+ def _augment(img):
+ if hflip:
+ img = img[:, ::-1, :]
+ if vflip:
+ img = img[::-1, :, :]
+ if rot90:
+ img = img.transpose(1, 0, 2)
+ return img
+
+ return [_augment(img) for img in img_list]
+
+
+'''
+# --------------------------------------------
+# modcrop and shave
+# --------------------------------------------
+'''
+
+
+def modcrop(img_in, scale):
+ # img_in: Numpy, HWC or HW
+ img = np.copy(img_in)
+ if img.ndim == 2:
+ H, W = img.shape
+ H_r, W_r = H % scale, W % scale
+ img = img[:H - H_r, :W - W_r]
+ elif img.ndim == 3:
+ H, W, C = img.shape
+ H_r, W_r = H % scale, W % scale
+ img = img[:H - H_r, :W - W_r, :]
+ else:
+ raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
+ return img
+
+
+def shave(img_in, border=0):
+ # img_in: Numpy, HWC or HW
+ img = np.copy(img_in)
+ h, w = img.shape[:2]
+ img = img[border:h-border, border:w-border]
+ return img
+
+
+'''
+# --------------------------------------------
+# image processing process on numpy image
+# channel_convert(in_c, tar_type, img_list):
+# rgb2ycbcr(img, only_y=True):
+# bgr2ycbcr(img, only_y=True):
+# ycbcr2rgb(img):
+# --------------------------------------------
+'''
+
+
+def rgb2ycbcr(img, only_y=True):
+ '''same as matlab rgb2ycbcr
+ only_y: only return Y channel
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+    img = img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ if only_y:
+ rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
+ else:
+ rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+ [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def ycbcr2rgb(img):
+ '''same as matlab ycbcr2rgb
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+    img = img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def bgr2ycbcr(img, only_y=True):
+ '''bgr version of rgb2ycbcr
+ only_y: only return Y channel
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+    img = img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ if only_y:
+ rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
+ else:
+ rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+ [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def channel_convert(in_c, tar_type, img_list):
+ # conversion among BGR, gray and y
+ if in_c == 3 and tar_type == 'gray': # BGR to gray
+ gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
+ return [np.expand_dims(img, axis=2) for img in gray_list]
+ elif in_c == 3 and tar_type == 'y': # BGR to y
+ y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
+ return [np.expand_dims(img, axis=2) for img in y_list]
+ elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
+ return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
+ else:
+ return img_list
+
+
+'''
+# --------------------------------------------
+# metric, PSNR and SSIM
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# PSNR
+# --------------------------------------------
+def calculate_psnr(img1, img2, border=0):
+ # img1 and img2 have range [0, 255]
+ #img1 = img1.squeeze()
+ #img2 = img2.squeeze()
+ if not img1.shape == img2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+ h, w = img1.shape[:2]
+ img1 = img1[border:h-border, border:w-border]
+ img2 = img2[border:h-border, border:w-border]
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ mse = np.mean((img1 - img2)**2)
+ if mse == 0:
+ return float('inf')
+ return 20 * math.log10(255.0 / math.sqrt(mse))
+
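+# Worked example: for uint8 images with MSE = 100, calculate_psnr returns
+# 20 * log10(255 / sqrt(100)) = 20 * log10(25.5) ~= 28.13 dB; identical
+# images return float('inf').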
+
+# --------------------------------------------
+# SSIM
+# --------------------------------------------
+def calculate_ssim(img1, img2, border=0):
+ '''calculate SSIM
+ the same outputs as MATLAB's
+ img1, img2: [0, 255]
+ '''
+ #img1 = img1.squeeze()
+ #img2 = img2.squeeze()
+ if not img1.shape == img2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+ h, w = img1.shape[:2]
+ img1 = img1[border:h-border, border:w-border]
+ img2 = img2[border:h-border, border:w-border]
+
+ if img1.ndim == 2:
+ return ssim(img1, img2)
+ elif img1.ndim == 3:
+ if img1.shape[2] == 3:
+ ssims = []
+ for i in range(3):
+ ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
+ return np.array(ssims).mean()
+ elif img1.shape[2] == 1:
+ return ssim(np.squeeze(img1), np.squeeze(img2))
+ else:
+ raise ValueError('Wrong input image dimensions.')
+
+
+def ssim(img1, img2):
+ C1 = (0.01 * 255)**2
+ C2 = (0.03 * 255)**2
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ kernel = cv2.getGaussianKernel(11, 1.5)
+ window = np.outer(kernel, kernel.transpose())
+
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+ mu1_sq = mu1**2
+ mu2_sq = mu2**2
+ mu1_mu2 = mu1 * mu2
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+ (sigma1_sq + sigma2_sq + C2))
+ return ssim_map.mean()
+
+
+'''
+# --------------------------------------------
+# matlab's bicubic imresize (numpy and torch) [0, 1]
+# --------------------------------------------
+'''
+
+
+# matlab 'imresize' function, now only support 'bicubic'
+def cubic(x):
+ absx = torch.abs(x)
+ absx2 = absx**2
+ absx3 = absx**3
+ return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
+ (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
+
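+# The piecewise cubic convolution kernel implemented by cubic() above
+# (Keys' kernel with a = -0.5, as used by MATLAB's bicubic imresize):
+#   W(x) =  1.5|x|^3 - 2.5|x|^2 + 1            for |x| <= 1
+#   W(x) = -0.5|x|^3 + 2.5|x|^2 - 4|x| + 2     for 1 < |x| <= 2
+#   W(x) =  0                                  otherwise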
+
+def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+ if (scale < 1) and (antialiasing):
+        # Use a modified kernel to simultaneously interpolate and antialias; this widens the kernel
+ kernel_width = kernel_width / scale
+
+ # Output-space coordinates
+ x = torch.linspace(1, out_length, out_length)
+
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
+ # in output space maps to 0.5 in input space, and 0.5+scale in output
+ # space maps to 1.5 in input space.
+ u = x / scale + 0.5 * (1 - 1 / scale)
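+    # Worked example (illustrative): for a 4x downscale, scale = 0.25, so output
+    # sample x = 1 maps to u = 1 / 0.25 + 0.5 * (1 - 1 / 0.25) = 4 - 1.5 = 2.5,
+    # i.e. midway between input samples 2 and 3 (1-based coordinates).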
+
+ # What is the left-most pixel that can be involved in the computation?
+ left = torch.floor(u - kernel_width / 2)
+
+ # What is the maximum number of pixels that can be involved in the
+ # computation? Note: it's OK to use an extra pixel here; if the
+ # corresponding weights are all zero, it will be eliminated at the end
+ # of this function.
+ P = math.ceil(kernel_width) + 2
+
+ # The indices of the input pixels involved in computing the k-th output
+ # pixel are in row k of the indices matrix.
+ indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
+ 1, P).expand(out_length, P)
+
+ # The weights used to compute the k-th output pixel are in row k of the
+ # weights matrix.
+ distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
+ # apply cubic kernel
+ if (scale < 1) and (antialiasing):
+ weights = scale * cubic(distance_to_center * scale)
+ else:
+ weights = cubic(distance_to_center)
+ # Normalize the weights matrix so that each row sums to 1.
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
+ weights = weights / weights_sum.expand(out_length, P)
+
+ # If a column in weights is all zero, get rid of it. only consider the first and last column.
+ weights_zero_tmp = torch.sum((weights == 0), 0)
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 1, P - 2)
+ weights = weights.narrow(1, 1, P - 2)
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 0, P - 2)
+ weights = weights.narrow(1, 0, P - 2)
+ weights = weights.contiguous()
+ indices = indices.contiguous()
+ sym_len_s = -indices.min() + 1
+ sym_len_e = indices.max() - in_length
+ indices = indices + sym_len_s - 1
+ return weights, indices, int(sym_len_s), int(sym_len_e)
+
+
+# --------------------------------------------
+# imresize for tensor image [0, 1]
+# --------------------------------------------
+def imresize(img, scale, antialiasing=True):
+ # Now the scale should be the same for H and W
+ # input: img: pytorch tensor, CHW or HW [0,1]
+ # output: CHW or HW [0,1] w/o round
+ need_squeeze = True if img.dim() == 2 else False
+ if need_squeeze:
+ img.unsqueeze_(0)
+ in_C, in_H, in_W = img.size()
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # Return the desired dimension order for performing the resize. The
+ # strategy is to perform the resize first along the dimension with the
+ # smallest scale factor.
+ # Now we do not support this.
+
+ # get weights and indices
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
+ img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
+
+ sym_patch = img[:, :sym_len_Hs, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+ sym_patch = img[:, -sym_len_He:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(in_C, out_H, in_W)
+ kernel_width = weights_H.size(1)
+ for i in range(out_H):
+ idx = int(indices_H[i][0])
+ for j in range(out_C):
+ out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
+ out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
+
+ sym_patch = out_1[:, :, :sym_len_Ws]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, :, -sym_len_We:]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(in_C, out_H, out_W)
+ kernel_width = weights_W.size(1)
+ for i in range(out_W):
+ idx = int(indices_W[i][0])
+ for j in range(out_C):
+ out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
+ if need_squeeze:
+ out_2.squeeze_()
+ return out_2
+
+
+# --------------------------------------------
+# imresize for numpy image [0, 1]
+# --------------------------------------------
+def imresize_np(img, scale, antialiasing=True):
+ # Now the scale should be the same for H and W
+ # input: img: Numpy, HWC or HW [0,1]
+ # output: HWC or HW [0,1] w/o round
+ img = torch.from_numpy(img)
+ need_squeeze = True if img.dim() == 2 else False
+ if need_squeeze:
+ img.unsqueeze_(2)
+
+ in_H, in_W, in_C = img.size()
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # Return the desired dimension order for performing the resize. The
+ # strategy is to perform the resize first along the dimension with the
+ # smallest scale factor.
+ # Now we do not support this.
+
+ # get weights and indices
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
+ img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
+
+ sym_patch = img[:sym_len_Hs, :, :]
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
+ img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+ sym_patch = img[-sym_len_He:, :, :]
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
+ img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(out_H, in_W, in_C)
+ kernel_width = weights_H.size(1)
+ for i in range(out_H):
+ idx = int(indices_H[i][0])
+ for j in range(out_C):
+ out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
+ out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
+
+ sym_patch = out_1[:, :sym_len_Ws, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, -sym_len_We:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(out_H, out_W, in_C)
+ kernel_width = weights_W.size(1)
+ for i in range(out_W):
+ idx = int(indices_W[i][0])
+ for j in range(out_C):
+ out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
+ if need_squeeze:
+ out_2.squeeze_()
+
+ return out_2.numpy()
+
+
+if __name__ == '__main__':
+ print('---')
+# img = imread_uint('test.bmp', 3)
+# img = uint2single(img)
+# img_bicubic = imresize_np(img, 1/4)
\ No newline at end of file
diff --git a/StableSR/ldm/modules/losses/__init__.py b/StableSR/ldm/modules/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..876d7c5bd6e3245ee77feb4c482b7a8143604ad5
--- /dev/null
+++ b/StableSR/ldm/modules/losses/__init__.py
@@ -0,0 +1 @@
+from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
\ No newline at end of file
diff --git a/StableSR/ldm/modules/losses/contperceptual.py b/StableSR/ldm/modules/losses/contperceptual.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa8da1cf344ab7ff8d7e5fd4deb0dbfeb54536e8
--- /dev/null
+++ b/StableSR/ldm/modules/losses/contperceptual.py
@@ -0,0 +1,151 @@
+import torch
+import torch.nn as nn
+
+from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
+
+
+class LPIPSWithDiscriminator(nn.Module):
+ def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+ disc_loss="hinge"):
+
+ super().__init__()
+ assert disc_loss in ["hinge", "vanilla"]
+ self.kl_weight = kl_weight
+ self.pixel_weight = pixelloss_weight
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+ # output log variance
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
+
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
+ n_layers=disc_num_layers,
+ use_actnorm=use_actnorm
+ ).apply(weights_init)
+ self.discriminator_iter_start = disc_start
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.disc_conditional = disc_conditional
+
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+ if last_layer is not None:
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+ else:
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
+
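+    # Restating the method above: d_weight = ||grad(nll_loss)|| / (||grad(g_loss)|| + 1e-4),
+    # computed at the decoder's last layer, clamped to [0, 1e4] and scaled by
+    # discriminator_weight, so the adversarial term is rescaled to roughly match
+    # the magnitude of the reconstruction/perceptual gradient (the VQGAN heuristic).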
+ def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
+ global_step, last_layer=None, cond=None, split="train",
+ weights=None, return_dic=False):
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+ if self.perceptual_weight > 0:
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
+ weighted_nll_loss = nll_loss
+ if weights is not None:
+ weighted_nll_loss = weights*nll_loss
+ weighted_nll_loss = torch.mean(weighted_nll_loss) / weighted_nll_loss.shape[0]
+ nll_loss = torch.mean(nll_loss) / nll_loss.shape[0]
+ if self.kl_weight>0:
+ kl_loss = posteriors.kl()
+ kl_loss = torch.mean(kl_loss) / kl_loss.shape[0]
+
+ # now the GAN part
+ if optimizer_idx == 0:
+ # generator update
+ if cond is None:
+ assert not self.disc_conditional
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ else:
+ assert self.disc_conditional
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
+ g_loss = -torch.mean(logits_fake)
+
+ if self.disc_factor > 0.0:
+ try:
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
+ except RuntimeError:
+ # assert not self.training
+ d_weight = torch.tensor(1.0) * self.discriminator_weight
+ else:
+                d_weight = torch.tensor(0.0)
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ if self.kl_weight>0:
+ loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
+ "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
+ "{}/d_weight".format(split): d_weight.detach(),
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
+ "{}/g_loss".format(split): g_loss.detach().mean(),
+ }
+ if return_dic:
+ loss_dic = {}
+ loss_dic['total_loss'] = loss.clone().detach().mean()
+ loss_dic['logvar'] = self.logvar.detach()
+ loss_dic['kl_loss'] = kl_loss.detach().mean()
+ loss_dic['nll_loss'] = nll_loss.detach().mean()
+ loss_dic['rec_loss'] = rec_loss.detach().mean()
+ loss_dic['d_weight'] = d_weight.detach()
+ loss_dic['disc_factor'] = torch.tensor(disc_factor)
+ loss_dic['g_loss'] = g_loss.detach().mean()
+ else:
+ loss = weighted_nll_loss + d_weight * disc_factor * g_loss
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
+ "{}/d_weight".format(split): d_weight.detach(),
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
+ "{}/g_loss".format(split): g_loss.detach().mean(),
+ }
+ if return_dic:
+ loss_dic = {}
+ loss_dic["{}/total_loss".format(split)] = loss.clone().detach().mean()
+ loss_dic["{}/logvar".format(split)] = self.logvar.detach()
+ loss_dic["{}/nll_loss".format(split)] = nll_loss.detach().mean()
+ loss_dic["{}/rec_loss".format(split)] = rec_loss.detach().mean()
+ loss_dic["{}/d_weight".format(split)] = d_weight.detach()
+ loss_dic["{}/disc_factor".format(split)] = torch.tensor(disc_factor)
+ loss_dic["{}/g_loss".format(split)] = g_loss.detach().mean()
+
+ if return_dic:
+ return loss, log, loss_dic
+ return loss, log
+
+ if optimizer_idx == 1:
+ # second pass for discriminator update
+ if cond is None:
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+ else:
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+ log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+ "{}/logits_real".format(split): logits_real.detach().mean(),
+ "{}/logits_fake".format(split): logits_fake.detach().mean()
+ }
+
+ if return_dic:
+ loss_dic = {}
+ loss_dic["{}/disc_loss".format(split)] = d_loss.clone().detach().mean()
+ loss_dic["{}/logits_real".format(split)] = logits_real.detach().mean()
+ loss_dic["{}/logits_fake".format(split)] = logits_fake.detach().mean()
+ return d_loss, log, loss_dic
+
+ return d_loss, log
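
For reference, the adaptive weight computed above rescales the GAN term so that its gradient at the decoder's last layer has roughly the same magnitude as the reconstruction/NLL gradient. A minimal, self-contained sketch of that balancing with toy tensors (illustrative stand-ins, not StableSR modules):

    import torch
    import torch.nn as nn

    # Toy "last layer" and two scalar losses, standing in for nll_loss and g_loss.
    last_layer = nn.Parameter(torch.randn(4, 4))
    x = torch.randn(4)
    nll_loss = (last_layer @ x).pow(2).mean()
    g_loss = (last_layer @ x).abs().mean()

    # Same recipe as calculate_adaptive_weight: ratio of gradient norms, clamped and detached.
    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    d_weight = (nll_grads.norm() / (g_grads.norm() + 1e-4)).clamp(0.0, 1e4).detach()
    print(d_weight)  # multiplies the generator loss before it is added to the total

The try/except around the same call in forward() falls back to a constant weight when autograd cannot produce these gradients (the commented-out assert suggests this is expected outside of training).
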
diff --git a/StableSR/ldm/modules/losses/vqperceptual.py b/StableSR/ldm/modules/losses/vqperceptual.py
new file mode 100644
index 0000000000000000000000000000000000000000..f69981769e4bd5462600458c4fcf26620f7e4306
--- /dev/null
+++ b/StableSR/ldm/modules/losses/vqperceptual.py
@@ -0,0 +1,167 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from einops import repeat
+
+from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
+from taming.modules.losses.lpips import LPIPS
+from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
+
+
+def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
+ assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
+ loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
+ loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
+ loss_real = (weights * loss_real).sum() / weights.sum()
+ loss_fake = (weights * loss_fake).sum() / weights.sum()
+ d_loss = 0.5 * (loss_real + loss_fake)
+ return d_loss
+
+def adopt_weight(weight, global_step, threshold=0, value=0.):
+ if global_step < threshold:
+ weight = value
+ return weight
+
+
+def measure_perplexity(predicted_indices, n_embed):
+ # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
+ # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
+ encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
+ avg_probs = encodings.mean(0)
+ perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
+ cluster_use = torch.sum(avg_probs > 0)
+ return perplexity, cluster_use
+
+def l1(x, y):
+ return torch.abs(x-y)
+
+
+def l2(x, y):
+ return torch.pow((x-y), 2)
+
+
+class VQLPIPSWithDiscriminator(nn.Module):
+ def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+ disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
+ pixel_loss="l1"):
+ super().__init__()
+ assert disc_loss in ["hinge", "vanilla"]
+ assert perceptual_loss in ["lpips", "clips", "dists"]
+ assert pixel_loss in ["l1", "l2"]
+ self.codebook_weight = codebook_weight
+ self.pixel_weight = pixelloss_weight
+ if perceptual_loss == "lpips":
+ print(f"{self.__class__.__name__}: Running with LPIPS.")
+ self.perceptual_loss = LPIPS().eval()
+ else:
+ raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
+ self.perceptual_weight = perceptual_weight
+
+ if pixel_loss == "l1":
+ self.pixel_loss = l1
+ else:
+ self.pixel_loss = l2
+
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
+ n_layers=disc_num_layers,
+ use_actnorm=use_actnorm,
+ ndf=disc_ndf
+ ).apply(weights_init)
+ self.discriminator_iter_start = disc_start
+ if disc_loss == "hinge":
+ self.disc_loss = hinge_d_loss
+ elif disc_loss == "vanilla":
+ self.disc_loss = vanilla_d_loss
+ else:
+ raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
+ print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.disc_conditional = disc_conditional
+ self.n_classes = n_classes
+
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+ if last_layer is not None:
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+ else:
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
+
+ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
+ global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
+ if codebook_loss is None:
+ codebook_loss = torch.tensor([0.]).to(inputs.device)
+ #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+ rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
+ if self.perceptual_weight > 0:
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+ else:
+ p_loss = torch.tensor([0.0])
+
+ nll_loss = rec_loss
+ #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+ nll_loss = torch.mean(nll_loss)
+
+ # now the GAN part
+ if optimizer_idx == 0:
+ # generator update
+ if cond is None:
+ assert not self.disc_conditional
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ else:
+ assert self.disc_conditional
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
+ g_loss = -torch.mean(logits_fake)
+
+ try:
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
+ except RuntimeError:
+ assert not self.training
+ d_weight = torch.tensor(0.0)
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
+
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
+ "{}/quant_loss".format(split): codebook_loss.detach().mean(),
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
+ "{}/p_loss".format(split): p_loss.detach().mean(),
+ "{}/d_weight".format(split): d_weight.detach(),
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
+ "{}/g_loss".format(split): g_loss.detach().mean(),
+ }
+ if predicted_indices is not None:
+ assert self.n_classes is not None
+ with torch.no_grad():
+ perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
+ log[f"{split}/perplexity"] = perplexity
+ log[f"{split}/cluster_usage"] = cluster_usage
+ return loss, log
+
+ if optimizer_idx == 1:
+ # second pass for discriminator update
+ if cond is None:
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+ else:
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+ log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+ "{}/logits_real".format(split): logits_real.detach().mean(),
+ "{}/logits_fake".format(split): logits_fake.detach().mean()
+ }
+ return d_loss, log
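
As a quick, self-contained check of the two logging helpers above (assuming the functions defined in this file are in scope), adopt_weight gates a loss term until a warm-up step is reached, and measure_perplexity reports how evenly the codebook is used:

    import torch

    print(adopt_weight(weight=1.0, global_step=100, threshold=500))   # 0.0 before the threshold
    print(adopt_weight(weight=1.0, global_step=1000, threshold=500))  # 1.0 once the threshold is passed

    indices = torch.randint(0, 16, (4, 32 * 32))             # fake predicted codebook indices
    perplexity, cluster_use = measure_perplexity(indices, n_embed=16)
    print(perplexity.item(), cluster_use.item())              # close to 16, and exactly 16, when all codes are hit
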
diff --git a/StableSR/ldm/modules/spade.py b/StableSR/ldm/modules/spade.py
new file mode 100644
index 0000000000000000000000000000000000000000..72845bdfb5ac0139aaa509681208804dc8444e71
--- /dev/null
+++ b/StableSR/ldm/modules/spade.py
@@ -0,0 +1,111 @@
+"""
+Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
+Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+"""
+
+import re
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+# from models.networks.sync_batchnorm import SynchronizedBatchNorm2d
+import torch.nn.utils.spectral_norm as spectral_norm
+
+from ldm.modules.diffusionmodules.util import normalization
+
+
+# Returns a function that creates a normalization function
+# that does not condition on semantic map
+def get_nonspade_norm_layer(opt, norm_type='instance'):
+ # helper function to get # output channels of the previous layer
+ def get_out_channel(layer):
+ if hasattr(layer, 'out_channels'):
+ return getattr(layer, 'out_channels')
+ return layer.weight.size(0)
+
+ # this function will be returned
+ def add_norm_layer(layer):
+ nonlocal norm_type
+ if norm_type.startswith('spectral'):
+ layer = spectral_norm(layer)
+ subnorm_type = norm_type[len('spectral'):]
+
+ if subnorm_type == 'none' or len(subnorm_type) == 0:
+ return layer
+
+ # remove bias in the previous layer, which is meaningless
+ # since it has no effect after normalization
+ if getattr(layer, 'bias', None) is not None:
+ delattr(layer, 'bias')
+ layer.register_parameter('bias', None)
+
+ if subnorm_type == 'batch':
+ norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
+ elif subnorm_type == 'sync_batch':
+ norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
+ elif subnorm_type == 'instance':
+ norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
+ else:
+ raise ValueError('normalization layer %s is not recognized' % subnorm_type)
+
+ return nn.Sequential(layer, norm_layer)
+
+ return add_norm_layer
+
+
+# Creates SPADE normalization layer based on the given configuration
+# SPADE consists of two steps. First, it normalizes the activations using
+# your favorite normalization method, such as Batch Norm or Instance Norm.
+# Second, it applies scale and bias to the normalized output, conditioned on
+# the segmentation map.
+# The format of |config_text| is spade(norm)(ks), where
+# (norm) specifies the type of parameter-free normalization.
+# (e.g. syncbatch, batch, instance)
+# (ks) specifies the size of kernel in the SPADE module (e.g. 3x3)
+# Example |config_text| will be spadesyncbatch3x3, or spadeinstance5x5.
+# Also, the other arguments are
+# |norm_nc|: the #channels of the normalized activations, hence the output dim of SPADE
+# |label_nc|: the #channels of the input semantic map, hence the input dim of SPADE
+class SPADE(nn.Module):
+ def __init__(self, norm_nc, label_nc, config_text='spadeinstance3x3'):
+ super().__init__()
+
+ assert config_text.startswith('spade')
+ parsed = re.search(r'spade(\D+)(\d)x\d', config_text)
+ param_free_norm_type = str(parsed.group(1))
+ ks = int(parsed.group(2))
+
+ self.param_free_norm = normalization(norm_nc)
+
+ # The dimension of the intermediate embedding space. Yes, hardcoded.
+ nhidden = 128
+
+ pw = ks // 2
+ self.mlp_shared = nn.Sequential(
+ nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
+ nn.ReLU()
+ )
+ self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
+ self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
+
+ def forward(self, x_dic, segmap_dic, size=None):
+
+ if size is None:
+ segmap = segmap_dic[str(x_dic.size(-1))]
+ x = x_dic
+ else:
+ x = x_dic[str(size)]
+ segmap = segmap_dic[str(size)]
+
+ # Part 1. generate parameter-free normalized activations
+ normalized = self.param_free_norm(x)
+
+ # Part 2. produce scaling and bias conditioned on semantic map
+ # segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
+ actv = self.mlp_shared(segmap)
+ gamma = self.mlp_gamma(actv)
+ beta = self.mlp_beta(actv)
+
+ # apply scale and bias
+ out = normalized * (1 + gamma) + beta
+
+ return out
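
A hedged usage sketch for the SPADE block above (assuming the class in this file is in scope; the shapes are illustrative): the module normalizes a feature map, then modulates it with gamma/beta predicted from a conditioning map that is looked up by spatial size.

    import torch

    norm_nc, label_nc = 64, 64
    spade = SPADE(norm_nc=norm_nc, label_nc=label_nc, config_text='spadeinstance3x3')

    x = torch.randn(2, norm_nc, 32, 32)                    # features to be normalized
    segmap_dic = {'32': torch.randn(2, label_nc, 32, 32)}  # conditioning map keyed by resolution
    out = spade(x, segmap_dic)                              # size=None -> key is str(x.size(-1)) == '32'
    print(out.shape)                                        # torch.Size([2, 64, 32, 32])
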
diff --git a/StableSR/ldm/modules/swinir.py b/StableSR/ldm/modules/swinir.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4a6ac8510f818997dc10ec0d4d0535b4f3c7130
--- /dev/null
+++ b/StableSR/ldm/modules/swinir.py
@@ -0,0 +1,854 @@
+# -----------------------------------------------------------------------------------
+# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
+# Originally Written by Ze Liu, Modified by Jingyun Liang.
+# -----------------------------------------------------------------------------------
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+class WindowAttention(nn.Module):
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both of shifted and non-shifted window.
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ """
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ trunc_normal_(self.relative_position_bias_table, std=.02)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
+
+ def flops(self, N):
+ # calculate flops for 1 window with token length of N
+ flops = 0
+ # qkv = self.qkv(x)
+ flops += N * self.dim * 3 * self.dim
+ # attn = (q @ k.transpose(-2, -1))
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
+ # x = (attn @ v)
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
+ # x = self.proj(x)
+ flops += N * self.dim * self.dim
+ return flops
+
+
+class SwinTransformerBlock(nn.Module):
+ r""" Swin Transformer Block.
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ if min(self.input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = 0
+ self.window_size = min(self.input_resolution)
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if self.shift_size > 0:
+ attn_mask = self.calculate_mask(self.input_resolution)
+ else:
+ attn_mask = None
+
+ self.register_buffer("attn_mask", attn_mask)
+
+ def calculate_mask(self, x_size):
+ # calculate attention mask for SW-MSA
+ H, W = x_size
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ return attn_mask
+
+ def forward(self, x, x_size):
+ H, W = x_size
+ B, L, C = x.shape
+ # assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_x = x
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA (to be compatible with testing on images whose shapes are not a multiple of the window size)
+ if self.input_resolution == x_size:
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
+ else:
+ attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+ x = x.view(B, H * W, C)
+
+ # FFN
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+
+ def flops(self):
+ flops = 0
+ H, W = self.input_resolution
+ # norm1
+ flops += self.dim * H * W
+ # W-MSA/SW-MSA
+ nW = H * W / self.window_size / self.window_size
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
+ # mlp
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+ # norm2
+ flops += self.dim * H * W
+ return flops
+
+
+class PatchMerging(nn.Module):
+ r""" Patch Merging Layer.
+ Args:
+ input_resolution (tuple[int]): Resolution of input feature.
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def forward(self, x):
+ """
+ x: B, H*W, C
+ """
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
+
+ x = x.view(B, H, W, C)
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+ def flops(self):
+ H, W = self.input_resolution
+ flops = H * W * self.dim
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
+ return flops
+
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
+
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
+ num_heads=num_heads, window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop, attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x, x_size):
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x, x_size)
+ else:
+ x = blk(x, x_size)
+ if self.downsample is not None:
+ x = self.downsample(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+ def flops(self):
+ flops = 0
+ for blk in self.blocks:
+ flops += blk.flops()
+ if self.downsample is not None:
+ flops += self.downsample.flops()
+ return flops
+
+
+class RSTB(nn.Module):
+ """Residual Swin Transformer Block (RSTB).
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ img_size: Input image size.
+ patch_size: Patch size.
+ resi_connection: The convolutional block before residual connection.
+ """
+
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
+ img_size=224, patch_size=4, resi_connection='1conv'):
+ super(RSTB, self).__init__()
+
+ self.dim = dim
+ self.input_resolution = input_resolution
+
+ self.residual_group = BasicLayer(dim=dim,
+ input_resolution=input_resolution,
+ depth=depth,
+ num_heads=num_heads,
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop, attn_drop=attn_drop,
+ drop_path=drop_path,
+ norm_layer=norm_layer,
+ downsample=downsample,
+ use_checkpoint=use_checkpoint)
+
+ if resi_connection == '1conv':
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
+ elif resi_connection == '3conv':
+ # to save parameters and memory
+ self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
+ norm_layer=None)
+
+ self.patch_unembed = PatchUnEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
+ norm_layer=None)
+
+ def forward(self, x, x_size):
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
+
+ def flops(self):
+ flops = 0
+ flops += self.residual_group.flops()
+ H, W = self.input_resolution
+ flops += H * W * self.dim * self.dim * 9
+ flops += self.patch_embed.flops()
+ flops += self.patch_unembed.flops()
+
+ return flops
+
+
+class PatchEmbed(nn.Module):
+ r""" Image to Patch Embedding
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
+ if self.norm is not None:
+ x = self.norm(x)
+ return x
+
+ def flops(self):
+ flops = 0
+ H, W = self.img_size
+ if self.norm is not None:
+ flops += H * W * self.embed_dim
+ return flops
+
+
+class PatchUnEmbed(nn.Module):
+ r""" Image to Patch Unembedding
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ def forward(self, x, x_size):
+ B, HW, C = x.shape
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B C Ph Pw
+ return x
+
+ def flops(self):
+ flops = 0
+ return flops
+
+
+class Upsample(nn.Sequential):
+ """Upsample module.
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+ """
+
+ def __init__(self, scale, num_feat):
+ m = []
+ if (scale & (scale - 1)) == 0: # scale = 2^n
+ for _ in range(int(math.log(scale, 2))):
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(2))
+ elif scale == 3:
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(3))
+ else:
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
+ super(Upsample, self).__init__(*m)
+
+
+class UpsampleOneStep(nn.Sequential):
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
+ Used in lightweight SR to save parameters.
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+ """
+
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
+ self.num_feat = num_feat
+ self.input_resolution = input_resolution
+ m = []
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
+ m.append(nn.PixelShuffle(scale))
+ super(UpsampleOneStep, self).__init__(*m)
+
+ def flops(self):
+ H, W = self.input_resolution
+ flops = H * W * self.num_feat * 3 * 9
+ return flops
+
+
+class SwinIR(nn.Module):
+ r""" SwinIR
+ A PyTorch impl of: `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
+ Args:
+ img_size (int | tuple(int)): Input image size. Default 64
+ patch_size (int | tuple(int)): Patch size. Default: 1
+ in_chans (int): Number of input image channels. Default: 3
+ embed_dim (int): Patch embedding dimension. Default: 96
+ depths (tuple(int)): Depth of each Swin Transformer layer.
+ num_heads (tuple(int)): Number of attention heads in different layers.
+ window_size (int): Window size. Default: 7
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
+ drop_rate (float): Dropout rate. Default: 0
+ attn_drop_rate (float): Attention dropout rate. Default: 0
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
+ img_range: Image range. 1. or 255.
+ upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
+ """
+
+ def __init__(self, img_size=64, patch_size=1, in_chans=3,
+ embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
+ use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
+ **kwargs):
+ super(SwinIR, self).__init__()
+ num_in_ch = in_chans
+ num_out_ch = in_chans
+ num_feat = 64
+ self.img_range = img_range
+ # if in_chans == 3:
+ # rgb_mean = (0.4488, 0.4371, 0.4040)
+ # self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
+ # else:
+ self.mean = torch.zeros(1, 1, 1, 1)
+ self.upscale = upscale
+ self.upsampler = upsampler
+ self.window_size = window_size
+
+ #####################################################################################################
+ ################################### 1, shallow feature extraction ###################################
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
+
+ #####################################################################################################
+ ################################### 2, deep feature extraction ######################################
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.num_features = embed_dim
+ self.mlp_ratio = mlp_ratio
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+ num_patches = self.patch_embed.num_patches
+ patches_resolution = self.patch_embed.patches_resolution
+ self.patches_resolution = patches_resolution
+
+ # merge non-overlapping patches into image
+ self.patch_unembed = PatchUnEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+
+ # absolute position embedding
+ if self.ape:
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+ trunc_normal_(self.absolute_pos_embed, std=.02)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build Residual Swin Transformer blocks (RSTB)
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = RSTB(dim=embed_dim,
+ input_resolution=(patches_resolution[0],
+ patches_resolution[1]),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
+ norm_layer=norm_layer,
+ downsample=None,
+ use_checkpoint=use_checkpoint,
+ img_size=img_size,
+ patch_size=patch_size,
+ resi_connection=resi_connection
+
+ )
+ self.layers.append(layer)
+ self.norm = norm_layer(self.num_features)
+
+ # build the last conv layer in deep feature extraction
+ if resi_connection == '1conv':
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
+ elif resi_connection == '3conv':
+ # to save parameters and memory
+ self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
+
+ #####################################################################################################
+ ################################ 3, high quality image reconstruction ################################
+ if self.upsampler == 'pixelshuffle':
+ # for classical SR
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.upsample = Upsample(upscale, num_feat)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ elif self.upsampler == 'pixelshuffledirect':
+ # for lightweight SR (to save parameters)
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
+ (patches_resolution[0], patches_resolution[1]))
+ elif self.upsampler == 'nearest+conv':
+ # for real-world SR (less artifacts)
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
+ nn.LeakyReLU(inplace=True))
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ if self.upscale == 4:
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+ else:
+ # for image denoising and JPEG compression artifact reduction
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'absolute_pos_embed'}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {'relative_position_bias_table'}
+
+ def check_image_size(self, x):
+ _, _, h, w = x.size()
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
+ return x
+
+ def forward_features(self, x):
+ x_size = (x.shape[2], x.shape[3])
+ x = self.patch_embed(x)
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+
+ for layer in self.layers:
+ x = layer(x, x_size)
+
+ x = self.norm(x) # B L C
+ x = self.patch_unembed(x, x_size)
+
+ return x
+
+ def forward(self, x):
+ H, W = x.shape[2:]
+ x = self.check_image_size(x)
+
+ self.mean = self.mean.type_as(x)
+ x = (x - self.mean) * self.img_range
+
+ if self.upsampler == 'pixelshuffle':
+ # for classical SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.conv_before_upsample(x)
+ x = self.conv_last(self.upsample(x))
+ elif self.upsampler == 'pixelshuffledirect':
+ # for lightweight SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.upsample(x)
+ elif self.upsampler == 'nearest+conv':
+ # for real-world SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.conv_before_upsample(x)
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+ if self.upscale == 4:
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
+ else:
+ # for image denoising and JPEG compression artifact reduction
+ x_first = self.conv_first(x)
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
+ x = x + self.conv_last(res)
+
+ x = x / self.img_range + self.mean
+
+ return x[:, :, :H*self.upscale, :W*self.upscale]
+
+ def flops(self):
+ flops = 0
+ H, W = self.patches_resolution
+ flops += H * W * 3 * self.embed_dim * 9
+ flops += self.patch_embed.flops()
+ for i, layer in enumerate(self.layers):
+ flops += layer.flops()
+ flops += H * W * 3 * self.embed_dim * self.embed_dim
+ flops += self.upsample.flops()
+ return flops
+
+
+if __name__ == '__main__':
+ upscale = 4
+ window_size = 8
+ height = (1024 // upscale // window_size + 1) * window_size
+ width = (720 // upscale // window_size + 1) * window_size
+ model = SwinIR(upscale=2, img_size=(height, width),
+ window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
+ embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
+ print(model)
+ print(height, width, model.flops() / 1e9)
+
+ x = torch.randn((1, 3, height, width))
+ x = model(x)
+ print(x.shape)
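
A small, self-contained sanity check for the windowing helpers at the top of this file (assuming window_partition and window_reverse are in scope): when H and W are multiples of the window size, the two functions are exact inverses.

    import torch

    B, H, W, C, ws = 2, 16, 16, 8, 8
    x = torch.randn(B, H, W, C)
    windows = window_partition(x, ws)                  # (B * H/ws * W/ws, ws, ws, C) -> (8, 8, 8, 8)
    restored = window_reverse(windows, ws, H, W)
    print(windows.shape, torch.allclose(x, restored))  # torch.Size([8, 8, 8, 8]) True
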
diff --git a/StableSR/ldm/modules/x_transformer.py b/StableSR/ldm/modules/x_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fc15bf9cfe0111a910e7de33d04ffdec3877576
--- /dev/null
+++ b/StableSR/ldm/modules/x_transformer.py
@@ -0,0 +1,641 @@
+"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+from functools import partial
+from inspect import isfunction
+from collections import namedtuple
+from einops import rearrange, repeat, reduce
+
+# constants
+
+DEFAULT_DIM_HEAD = 64
+
+Intermediates = namedtuple('Intermediates', [
+ 'pre_softmax_attn',
+ 'post_softmax_attn'
+])
+
+LayerIntermediates = namedtuple('Intermediates', [
+ 'hiddens',
+ 'attn_intermediates'
+])
+
+
+class AbsolutePositionalEmbedding(nn.Module):
+ def __init__(self, dim, max_seq_len):
+ super().__init__()
+ self.emb = nn.Embedding(max_seq_len, dim)
+ self.init_()
+
+ def init_(self):
+ nn.init.normal_(self.emb.weight, std=0.02)
+
+ def forward(self, x):
+ n = torch.arange(x.shape[1], device=x.device)
+ return self.emb(n)[None, :, :]
+
+
+class FixedPositionalEmbedding(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer('inv_freq', inv_freq)
+
+ def forward(self, x, seq_dim=1, offset=0):
+ t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
+ sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
+ emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
+ return emb[None, :, :]
+
+
+# helpers
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def always(val):
+ def inner(*args, **kwargs):
+ return val
+ return inner
+
+
+def not_equals(val):
+ def inner(x):
+ return x != val
+ return inner
+
+
+def equals(val):
+ def inner(x):
+ return x == val
+ return inner
+
+
+def max_neg_value(tensor):
+ return -torch.finfo(tensor.dtype).max
+
+
+# keyword argument helpers
+
+def pick_and_pop(keys, d):
+ values = list(map(lambda key: d.pop(key), keys))
+ return dict(zip(keys, values))
+
+
+def group_dict_by_key(cond, d):
+ return_val = [dict(), dict()]
+ for key in d.keys():
+ match = bool(cond(key))
+ ind = int(not match)
+ return_val[ind][key] = d[key]
+ return (*return_val,)
+
+
+def string_begins_with(prefix, str):
+ return str.startswith(prefix)
+
+
+def group_by_key_prefix(prefix, d):
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
+
+
+def groupby_prefix_and_trim(prefix, d):
+ kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
+ kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
+ return kwargs_without_prefix, kwargs
+
+
+# classes
+class Scale(nn.Module):
+ def __init__(self, value, fn):
+ super().__init__()
+ self.value = value
+ self.fn = fn
+
+ def forward(self, x, **kwargs):
+ x, *rest = self.fn(x, **kwargs)
+ return (x * self.value, *rest)
+
+
+class Rezero(nn.Module):
+ def __init__(self, fn):
+ super().__init__()
+ self.fn = fn
+ self.g = nn.Parameter(torch.zeros(1))
+
+ def forward(self, x, **kwargs):
+ x, *rest = self.fn(x, **kwargs)
+ return (x * self.g, *rest)
+
+
+class ScaleNorm(nn.Module):
+ def __init__(self, dim, eps=1e-5):
+ super().__init__()
+ self.scale = dim ** -0.5
+ self.eps = eps
+ self.g = nn.Parameter(torch.ones(1))
+
+ def forward(self, x):
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
+ return x / norm.clamp(min=self.eps) * self.g
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim, eps=1e-8):
+ super().__init__()
+ self.scale = dim ** -0.5
+ self.eps = eps
+ self.g = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
+ return x / norm.clamp(min=self.eps) * self.g
+
+
+class Residual(nn.Module):
+ def forward(self, x, residual):
+ return x + residual
+
+
+class GRUGating(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.gru = nn.GRUCell(dim, dim)
+
+ def forward(self, x, residual):
+ gated_output = self.gru(
+ rearrange(x, 'b n d -> (b n) d'),
+ rearrange(residual, 'b n d -> (b n) d')
+ )
+
+ return gated_output.reshape_as(x)
+
+
+# feedforward
+
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(
+ nn.Linear(dim, inner_dim),
+ nn.GELU()
+ ) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(
+ project_in,
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+# attention.
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ dim_head=DEFAULT_DIM_HEAD,
+ heads=8,
+ causal=False,
+ mask=None,
+ talking_heads=False,
+ sparse_topk=None,
+ use_entmax15=False,
+ num_mem_kv=0,
+ dropout=0.,
+ on_attn=False
+ ):
+ super().__init__()
+ if use_entmax15:
+ raise NotImplementedError("Check out entmax activation instead of softmax activation!")
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+ self.causal = causal
+ self.mask = mask
+
+ inner_dim = dim_head * heads
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(dim, inner_dim, bias=False)
+ self.dropout = nn.Dropout(dropout)
+
+ # talking heads
+ self.talking_heads = talking_heads
+ if talking_heads:
+ self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
+ self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
+
+ # explicit topk sparse attention
+ self.sparse_topk = sparse_topk
+
+ # entmax
+ #self.attn_fn = entmax15 if use_entmax15 else F.softmax
+ self.attn_fn = F.softmax
+
+ # add memory key / values
+ self.num_mem_kv = num_mem_kv
+ if num_mem_kv > 0:
+ self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
+ self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
+
+ # attention on attention
+ self.attn_on_attn = on_attn
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ context_mask=None,
+ rel_pos=None,
+ sinusoidal_emb=None,
+ prev_attn=None,
+ mem=None
+ ):
+ b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
+ kv_input = default(context, x)
+
+ q_input = x
+ k_input = kv_input
+ v_input = kv_input
+
+ if exists(mem):
+ k_input = torch.cat((mem, k_input), dim=-2)
+ v_input = torch.cat((mem, v_input), dim=-2)
+
+ if exists(sinusoidal_emb):
+ # in shortformer, the query would start at a position offset depending on the past cached memory
+ offset = k_input.shape[-2] - q_input.shape[-2]
+ q_input = q_input + sinusoidal_emb(q_input, offset=offset)
+ k_input = k_input + sinusoidal_emb(k_input)
+
+ q = self.to_q(q_input)
+ k = self.to_k(k_input)
+ v = self.to_v(v_input)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
+
+ input_mask = None
+ if any(map(exists, (mask, context_mask))):
+ q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
+ k_mask = q_mask if not exists(context) else context_mask
+ k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
+ q_mask = rearrange(q_mask, 'b i -> b () i ()')
+ k_mask = rearrange(k_mask, 'b j -> b () () j')
+ input_mask = q_mask * k_mask
+
+ if self.num_mem_kv > 0:
+ mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
+ k = torch.cat((mem_k, k), dim=-2)
+ v = torch.cat((mem_v, v), dim=-2)
+ if exists(input_mask):
+ input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
+
+ dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
+ mask_value = max_neg_value(dots)
+
+ if exists(prev_attn):
+ dots = dots + prev_attn
+
+ pre_softmax_attn = dots
+
+ if talking_heads:
+ dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
+
+ if exists(rel_pos):
+ dots = rel_pos(dots)
+
+ if exists(input_mask):
+ dots.masked_fill_(~input_mask, mask_value)
+ del input_mask
+
+ if self.causal:
+ i, j = dots.shape[-2:]
+ r = torch.arange(i, device=device)
+ mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
+ mask = F.pad(mask, (j - i, 0), value=False)
+ dots.masked_fill_(mask, mask_value)
+ del mask
+
+ if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
+ top, _ = dots.topk(self.sparse_topk, dim=-1)
+ vk = top[..., -1].unsqueeze(-1).expand_as(dots)
+ mask = dots < vk
+ dots.masked_fill_(mask, mask_value)
+ del mask
+
+ attn = self.attn_fn(dots, dim=-1)
+ post_softmax_attn = attn
+
+ attn = self.dropout(attn)
+
+ if talking_heads:
+ attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
+
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
+ out = rearrange(out, 'b h n d -> b n (h d)')
+
+ intermediates = Intermediates(
+ pre_softmax_attn=pre_softmax_attn,
+ post_softmax_attn=post_softmax_attn
+ )
+
+ return self.to_out(out), intermediates
+
+
+class AttentionLayers(nn.Module):
+ def __init__(
+ self,
+ dim,
+ depth,
+ heads=8,
+ causal=False,
+ cross_attend=False,
+ only_cross=False,
+ use_scalenorm=False,
+ use_rmsnorm=False,
+ use_rezero=False,
+ rel_pos_num_buckets=32,
+ rel_pos_max_distance=128,
+ position_infused_attn=False,
+ custom_layers=None,
+ sandwich_coef=None,
+ par_ratio=None,
+ residual_attn=False,
+ cross_residual_attn=False,
+ macaron=False,
+ pre_norm=True,
+ gate_residual=False,
+ **kwargs
+ ):
+ super().__init__()
+ ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
+ attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
+
+ dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
+
+ self.dim = dim
+ self.depth = depth
+ self.layers = nn.ModuleList([])
+
+ self.has_pos_emb = position_infused_attn
+ self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
+ self.rotary_pos_emb = always(None)
+
+ assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
+ self.rel_pos = None
+
+ self.pre_norm = pre_norm
+
+ self.residual_attn = residual_attn
+ self.cross_residual_attn = cross_residual_attn
+
+ norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
+ norm_class = RMSNorm if use_rmsnorm else norm_class
+ norm_fn = partial(norm_class, dim)
+
+ norm_fn = nn.Identity if use_rezero else norm_fn
+ branch_fn = Rezero if use_rezero else None
+
+ if cross_attend and not only_cross:
+ default_block = ('a', 'c', 'f')
+ elif cross_attend and only_cross:
+ default_block = ('c', 'f')
+ else:
+ default_block = ('a', 'f')
+
+ if macaron:
+ default_block = ('f',) + default_block
+
+ if exists(custom_layers):
+ layer_types = custom_layers
+ elif exists(par_ratio):
+ par_depth = depth * len(default_block)
+ assert 1 < par_ratio <= par_depth, 'par ratio out of range'
+ default_block = tuple(filter(not_equals('f'), default_block))
+ par_attn = par_depth // par_ratio
+ depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
+ par_width = (depth_cut + depth_cut // par_attn) // par_attn
+ assert len(default_block) <= par_width, 'default block is too large for par_ratio'
+ par_block = default_block + ('f',) * (par_width - len(default_block))
+ par_head = par_block * par_attn
+ layer_types = par_head + ('f',) * (par_depth - len(par_head))
+ elif exists(sandwich_coef):
+ assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be positive and no greater than the depth'
+ layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
+ else:
+ layer_types = default_block * depth
+
+ self.layer_types = layer_types
+ self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
+
+ for layer_type in self.layer_types:
+ if layer_type == 'a':
+ layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
+ elif layer_type == 'c':
+ layer = Attention(dim, heads=heads, **attn_kwargs)
+ elif layer_type == 'f':
+ layer = FeedForward(dim, **ff_kwargs)
+ layer = layer if not macaron else Scale(0.5, layer)
+ else:
+ raise Exception(f'invalid layer type {layer_type}')
+
+ if isinstance(layer, Attention) and exists(branch_fn):
+ layer = branch_fn(layer)
+
+ if gate_residual:
+ residual_fn = GRUGating(dim)
+ else:
+ residual_fn = Residual()
+
+ self.layers.append(nn.ModuleList([
+ norm_fn(),
+ layer,
+ residual_fn
+ ]))
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ context_mask=None,
+ mems=None,
+ return_hiddens=False
+ ):
+ hiddens = []
+ intermediates = []
+ prev_attn = None
+ prev_cross_attn = None
+
+ mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
+
+ for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
+ is_last = ind == (len(self.layers) - 1)
+
+ if layer_type == 'a':
+ hiddens.append(x)
+ layer_mem = mems.pop(0)
+
+ residual = x
+
+ if self.pre_norm:
+ x = norm(x)
+
+ if layer_type == 'a':
+ out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
+ prev_attn=prev_attn, mem=layer_mem)
+ elif layer_type == 'c':
+ out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
+ elif layer_type == 'f':
+ out = block(x)
+
+ x = residual_fn(out, residual)
+
+ if layer_type in ('a', 'c'):
+ intermediates.append(inter)
+
+ if layer_type == 'a' and self.residual_attn:
+ prev_attn = inter.pre_softmax_attn
+ elif layer_type == 'c' and self.cross_residual_attn:
+ prev_cross_attn = inter.pre_softmax_attn
+
+ if not self.pre_norm and not is_last:
+ x = norm(x)
+
+ if return_hiddens:
+ intermediates = LayerIntermediates(
+ hiddens=hiddens,
+ attn_intermediates=intermediates
+ )
+
+ return x, intermediates
+
+ return x
+
+
+class Encoder(AttentionLayers):
+ def __init__(self, **kwargs):
+ assert 'causal' not in kwargs, 'cannot set causality on encoder'
+ super().__init__(causal=False, **kwargs)
+
+
+
+class TransformerWrapper(nn.Module):
+ def __init__(
+ self,
+ *,
+ num_tokens,
+ max_seq_len,
+ attn_layers,
+ emb_dim=None,
+ max_mem_len=0.,
+ emb_dropout=0.,
+ num_memory_tokens=None,
+ tie_embedding=False,
+ use_pos_emb=True
+ ):
+ super().__init__()
+ assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
+
+ dim = attn_layers.dim
+ emb_dim = default(emb_dim, dim)
+
+ self.max_seq_len = max_seq_len
+ self.max_mem_len = max_mem_len
+ self.num_tokens = num_tokens
+
+ self.token_emb = nn.Embedding(num_tokens, emb_dim)
+ self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
+ use_pos_emb and not attn_layers.has_pos_emb) else always(0)
+ self.emb_dropout = nn.Dropout(emb_dropout)
+
+ self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
+ self.attn_layers = attn_layers
+ self.norm = nn.LayerNorm(dim)
+
+ self.init_()
+
+ self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
+
+ # memory tokens (like [cls]) from Memory Transformers paper
+ num_memory_tokens = default(num_memory_tokens, 0)
+ self.num_memory_tokens = num_memory_tokens
+ if num_memory_tokens > 0:
+ self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
+
+ # let funnel encoder know number of memory tokens, if specified
+ if hasattr(attn_layers, 'num_memory_tokens'):
+ attn_layers.num_memory_tokens = num_memory_tokens
+
+ def init_(self):
+ nn.init.normal_(self.token_emb.weight, std=0.02)
+
+ def forward(
+ self,
+ x,
+ return_embeddings=False,
+ mask=None,
+ return_mems=False,
+ return_attn=False,
+ mems=None,
+ **kwargs
+ ):
+ b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
+ x = self.token_emb(x)
+ x += self.pos_emb(x)
+ x = self.emb_dropout(x)
+
+ x = self.project_emb(x)
+
+ if num_mem > 0:
+ mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
+ x = torch.cat((mem, x), dim=1)
+
+ # auto-handle masking after appending memory tokens
+ if exists(mask):
+ mask = F.pad(mask, (num_mem, 0), value=True)
+
+ x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
+ x = self.norm(x)
+
+ mem, x = x[:, :num_mem], x[:, num_mem:]
+
+ out = self.to_logits(x) if not return_embeddings else x
+
+ if return_mems:
+ hiddens = intermediates.hiddens
+ new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
+ new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
+ return out, new_mems
+
+ if return_attn:
+ attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
+ return out, attn_maps
+
+ return out
+
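+# Illustrative usage sketch (added note, not part of the original file): wiring a
+# TransformerWrapper around an Encoder, using only the constructor arguments
+# defined above; the numbers are arbitrary placeholders.
+#
+#   model = TransformerWrapper(
+#       num_tokens=10000,
+#       max_seq_len=256,
+#       attn_layers=Encoder(dim=512, depth=6, heads=8),
+#   )
+#   tokens = torch.randint(0, 10000, (1, 256))
+#   logits = model(tokens)  # shape (1, 256, 10000)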
diff --git a/StableSR/ldm/util.py b/StableSR/ldm/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b1301a55396c445ecdb28cc444fa10fcbd06391
--- /dev/null
+++ b/StableSR/ldm/util.py
@@ -0,0 +1,211 @@
+import importlib
+
+import torch
+import numpy as np
+from collections import abc
+from einops import rearrange
+from functools import partial
+
+import multiprocessing as mp
+from threading import Thread
+from queue import Queue
+
+from inspect import isfunction
+from PIL import Image, ImageDraw, ImageFont
+
+
+def log_txt_as_img(wh, xc, size=10):
+ # wh a tuple of (width, height)
+ # xc a list of captions to plot
+ b = len(xc)
+ txts = list()
+ for bi in range(b):
+ txt = Image.new("RGB", wh, color="white")
+ draw = ImageDraw.Draw(txt)
+ font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
+ nc = int(40 * (wh[0] / 256))
+ lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
+
+ try:
+ draw.text((0, 0), lines, fill="black", font=font)
+ except UnicodeEncodeError:
+ print("Cant encode string for logging. Skipping.")
+
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+ txts.append(txt)
+ txts = np.stack(txts)
+ txts = torch.tensor(txts)
+ return txts
+
+
+def ismap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def exists(x):
+ return x is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+ """
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+ return total_params
+
+
+def instantiate_from_config(config):
+ if not "target" in config:
+ if config == '__is_first_stage__':
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
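+# Illustrative note (not in the original file): a typical config passed to
+# instantiate_from_config looks like
+#   {"target": "torch.nn.Conv2d",
+#    "params": {"in_channels": 3, "out_channels": 64, "kernel_size": 3}}
+# and is instantiated as Conv2d(in_channels=3, out_channels=64, kernel_size=3).
+# instantiate_from_config_sr below differs only in passing the params dict as a
+# single positional argument instead of unpacking it as keyword arguments.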
+
+def instantiate_from_config_sr(config):
+ if not "target" in config:
+ if config == '__is_first_stage__':
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(config.get("params", dict()))
+
+def get_obj_from_str(string, reload=False):
+ module, cls = string.rsplit(".", 1)
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
+def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
+ # create dummy dataset instance
+
+ # run prefetching
+ if idx_to_fn:
+ res = func(data, worker_id=idx)
+ else:
+ res = func(data)
+ Q.put([idx, res])
+ Q.put("Done")
+
+
+def parallel_data_prefetch(
+ func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
+):
+ # if target_data_type not in ["ndarray", "list"]:
+ # raise ValueError(
+ # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
+ # )
+ if isinstance(data, np.ndarray) and target_data_type == "list":
+ raise ValueError("list expected but function got ndarray.")
+ elif isinstance(data, abc.Iterable):
+ if isinstance(data, dict):
+ print(
+ f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
+ )
+ data = list(data.values())
+ if target_data_type == "ndarray":
+ data = np.asarray(data)
+ else:
+ data = list(data)
+ else:
+ raise TypeError(
+ f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
+ )
+
+ if cpu_intensive:
+ Q = mp.Queue(1000)
+ proc = mp.Process
+ else:
+ Q = Queue(1000)
+ proc = Thread
+ # spawn processes
+ if target_data_type == "ndarray":
+ arguments = [
+ [func, Q, part, i, use_worker_id]
+ for i, part in enumerate(np.array_split(data, n_proc))
+ ]
+ else:
+ step = (
+ int(len(data) / n_proc + 1)
+ if len(data) % n_proc != 0
+ else int(len(data) / n_proc)
+ )
+ arguments = [
+ [func, Q, part, i, use_worker_id]
+ for i, part in enumerate(
+ [data[i: i + step] for i in range(0, len(data), step)]
+ )
+ ]
+ processes = []
+ for i in range(n_proc):
+ p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
+ processes += [p]
+
+ # start processes
+ print(f"Start prefetching...")
+ import time
+
+ start = time.time()
+ gather_res = [[] for _ in range(n_proc)]
+ try:
+ for p in processes:
+ p.start()
+
+ k = 0
+ while k < n_proc:
+ # get result
+ res = Q.get()
+ if res == "Done":
+ k += 1
+ else:
+ gather_res[res[0]] = res[1]
+
+ except Exception as e:
+ print("Exception: ", e)
+ for p in processes:
+ p.terminate()
+
+ raise e
+ finally:
+ for p in processes:
+ p.join()
+ print(f"Prefetching complete. [{time.time() - start} sec.]")
+
+ if target_data_type == 'ndarray':
+ if not isinstance(gather_res[0], np.ndarray):
+ return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
+
+ # order outputs
+ return np.concatenate(gather_res, axis=0)
+ elif target_data_type == 'list':
+ out = []
+ for r in gather_res:
+ out.extend(r)
+ return out
+ else:
+ return gather_res
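+# Illustrative usage sketch (not in the original file), assuming `my_func` maps a
+# chunk of items to a list of results:
+#   results = parallel_data_prefetch(my_func, items, n_proc=4,
+#                                    target_data_type="list", cpu_intensive=True)
+# With cpu_intensive=True each chunk is handled by a separate process; with
+# cpu_intensive=False, threads and a regular Queue are used instead.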
diff --git a/StableSR/main.py b/StableSR/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9a5b1b00c3f8c9b0df596bb961cda8882a61982
--- /dev/null
+++ b/StableSR/main.py
@@ -0,0 +1,743 @@
+import argparse, os, sys, datetime, glob, importlib, csv
+import numpy as np
+import time
+import torch
+import torchvision
+import pytorch_lightning as pl
+
+from packaging import version
+from omegaconf import OmegaConf
+from torch.utils.data import random_split, DataLoader, Dataset, Subset
+from functools import partial
+from PIL import Image
+
+from pytorch_lightning import seed_everything
+from pytorch_lightning.trainer import Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
+from pytorch_lightning.utilities.distributed import rank_zero_only
+# from pytorch_lightning.utilities.rank_zero import rank_zero_only
+from pytorch_lightning.utilities import rank_zero_info
+
+from ldm.data.base import Txt2ImgIterableBaseDataset
+from ldm.util import instantiate_from_config, instantiate_from_config_sr
+from pytorch_lightning.loggers import WandbLogger
+
+
+def get_parser(**parser_kwargs):
+ def str2bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ("yes", "true", "t", "y", "1"):
+ return True
+ elif v.lower() in ("no", "false", "f", "n", "0"):
+ return False
+ else:
+ raise argparse.ArgumentTypeError("Boolean value expected.")
+
+ parser = argparse.ArgumentParser(**parser_kwargs)
+ parser.add_argument(
+ "-n",
+ "--name",
+ type=str,
+ const=True,
+ default="",
+ nargs="?",
+ help="postfix for logdir",
+ )
+ parser.add_argument(
+ "-r",
+ "--resume",
+ type=str,
+ const=True,
+ default="",
+ nargs="?",
+ help="resume from logdir or checkpoint in logdir",
+ )
+ parser.add_argument(
+ "-b",
+ "--base",
+ nargs="*",
+ metavar="base_config.yaml",
+ help="paths to base configs. Loaded from left-to-right. "
+ "Parameters can be overwritten or added with command-line options of the form `--key value`.",
+ default=list(),
+ )
+ parser.add_argument(
+ "-t",
+ "--train",
+ type=str2bool,
+ const=True,
+ default=False,
+ nargs="?",
+ help="train",
+ )
+ parser.add_argument(
+ "--no-test",
+ type=str2bool,
+ const=True,
+ default=False,
+ nargs="?",
+ help="disable test",
+ )
+ parser.add_argument(
+ "-p",
+ "--project",
+ help="name of new or path to existing project"
+ )
+ parser.add_argument(
+ "-d",
+ "--debug",
+ type=str2bool,
+ nargs="?",
+ const=True,
+ default=False,
+ help="enable post-mortem debugging",
+ )
+ parser.add_argument(
+ "-s",
+ "--seed",
+ type=int,
+ default=23,
+ help="seed for seed_everything",
+ )
+ parser.add_argument(
+ "-f",
+ "--postfix",
+ type=str,
+ default="",
+ help="post-postfix for default name",
+ )
+ parser.add_argument(
+ "-l",
+ "--logdir",
+ type=str,
+ default="./logs",
+ help="directory for logging dat shit",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ type=str2bool,
+ nargs="?",
+ const=True,
+ default=False,
+ help="scale base-lr by ngpu * batch_size * n_accumulate",
+ )
+ return parser
+
+
+def nondefault_trainer_args(opt):
+ parser = argparse.ArgumentParser()
+ parser = Trainer.add_argparse_args(parser)
+ args = parser.parse_args([])
+ return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
+
+
+class WrappedDataset(Dataset):
+ """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
+
+ def __init__(self, dataset):
+ self.data = dataset
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ return self.data[idx]
+
+
+def worker_init_fn(_):
+ worker_info = torch.utils.data.get_worker_info()
+
+ dataset = worker_info.dataset
+ worker_id = worker_info.id
+
+ if isinstance(dataset, Txt2ImgIterableBaseDataset):
+ split_size = dataset.num_records // worker_info.num_workers
+ # give each worker a disjoint shard of the valid sample ids so the per-worker length stays consistent
+ dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
+ current_id = np.random.choice(len(np.random.get_state()[1]), 1)
+ return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
+ else:
+ return np.random.seed(np.random.get_state()[1][0] + worker_id)
+
+
+class DataModuleFromConfig(pl.LightningDataModule):
+ def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
+ wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
+ shuffle_val_dataloader=False):
+ super().__init__()
+ self.batch_size = batch_size
+ self.dataset_configs = dict()
+ self.num_workers = num_workers if num_workers is not None else batch_size * 2
+ self.use_worker_init_fn = use_worker_init_fn
+ if train is not None:
+ self.dataset_configs["train"] = train
+ self.train_dataloader = self._train_dataloader
+ if validation is not None:
+ self.dataset_configs["validation"] = validation
+ self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
+ if test is not None:
+ self.dataset_configs["test"] = test
+ self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
+ if predict is not None:
+ self.dataset_configs["predict"] = predict
+ self.predict_dataloader = self._predict_dataloader
+ self.wrap = wrap
+
+ def prepare_data(self):
+ for data_cfg in self.dataset_configs.values():
+ instantiate_from_config_sr(data_cfg)
+
+ def setup(self, stage=None):
+ self.datasets = dict(
+ (k, instantiate_from_config_sr(self.dataset_configs[k]))
+ for k in self.dataset_configs)
+ if self.wrap:
+ for k in self.datasets:
+ self.datasets[k] = WrappedDataset(self.datasets[k])
+
+ def _train_dataloader(self):
+ is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
+ if is_iterable_dataset or self.use_worker_init_fn:
+ init_fn = worker_init_fn
+ else:
+ init_fn = None
+ return DataLoader(self.datasets["train"], batch_size=self.batch_size,
+ num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True,
+ worker_init_fn=init_fn)
+
+ def _val_dataloader(self, shuffle=False):
+ if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
+ init_fn = worker_init_fn
+ else:
+ init_fn = None
+ return DataLoader(self.datasets["validation"],
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ worker_init_fn=init_fn,
+ shuffle=shuffle)
+
+ def _test_dataloader(self, shuffle=False):
+ is_iterable_dataset = isinstance(self.datasets['test'], Txt2ImgIterableBaseDataset)
+ if is_iterable_dataset or self.use_worker_init_fn:
+ init_fn = worker_init_fn
+ else:
+ init_fn = None
+
+ # do not shuffle dataloader for iterable dataset
+ shuffle = shuffle and (not is_iterable_dataset)
+
+ return DataLoader(self.datasets["test"], batch_size=self.batch_size,
+ num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle)
+
+ def _predict_dataloader(self, shuffle=False):
+ if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
+ init_fn = worker_init_fn
+ else:
+ init_fn = None
+ return DataLoader(self.datasets["predict"], batch_size=self.batch_size,
+ num_workers=self.num_workers, worker_init_fn=init_fn)
+
+
+class SetupCallback(Callback):
+ def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
+ super().__init__()
+ self.resume = resume
+ self.now = now
+ self.logdir = logdir
+ self.ckptdir = ckptdir
+ self.cfgdir = cfgdir
+ self.config = config
+ self.lightning_config = lightning_config
+
+ def on_keyboard_interrupt(self, trainer, pl_module):
+ if trainer.global_rank == 0:
+ print("Summoning checkpoint.")
+ ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
+ trainer.save_checkpoint(ckpt_path)
+
+ def on_pretrain_routine_start(self, trainer, pl_module):
+ if trainer.global_rank == 0:
+ # Create logdirs and save configs
+ os.makedirs(self.logdir, exist_ok=True)
+ os.makedirs(self.ckptdir, exist_ok=True)
+ os.makedirs(self.cfgdir, exist_ok=True)
+
+ if "callbacks" in self.lightning_config:
+ if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']:
+ os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
+ print("Project config")
+ print(OmegaConf.to_yaml(self.config))
+ OmegaConf.save(self.config,
+ os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
+
+ print("Lightning config")
+ print(OmegaConf.to_yaml(self.lightning_config))
+ OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
+ os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
+
+ else:
+ # ModelCheckpoint callback created log directory --- remove it
+ if not self.resume and os.path.exists(self.logdir):
+ dst, name = os.path.split(self.logdir)
+ dst = os.path.join(dst, "child_runs", name)
+ os.makedirs(os.path.split(dst)[0], exist_ok=True)
+ try:
+ os.rename(self.logdir, dst)
+ except FileNotFoundError:
+ pass
+
+
+class ImageLogger(Callback):
+ def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True,
+ rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
+ log_images_kwargs=None):
+ super().__init__()
+ self.rescale = rescale
+ self.batch_freq = batch_frequency
+ self.max_images = max_images
+ self.logger_log_images = {
+ pl.loggers.TestTubeLogger: self._testtube,
+ }
+ self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
+ if not increase_log_steps:
+ self.log_steps = [self.batch_freq]
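+ # e.g. with batch_frequency=750 and increase_log_steps=True this logs at steps
+ # 1, 2, 4, ..., 512 (int(log2(750)) == 9) and thereafter every 750 steps.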
+ self.clamp = clamp
+ self.disabled = disabled
+ self.log_on_batch_idx = log_on_batch_idx
+ self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
+ self.log_first_step = log_first_step
+
+ @rank_zero_only
+ def _testtube(self, pl_module, images, batch_idx, split):
+ for k in images:
+ grid = torchvision.utils.make_grid(images[k])
+ grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
+
+ tag = f"{split}/{k}"
+ pl_module.logger.experiment.add_image(
+ tag, grid,
+ global_step=pl_module.global_step)
+
+ @rank_zero_only
+ def log_local(self, save_dir, split, images,
+ global_step, current_epoch, batch_idx):
+ root = os.path.join(save_dir, "images", split)
+ for k in images:
+ grid = torchvision.utils.make_grid(images[k], nrow=4)
+ if self.rescale:
+ grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
+ grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
+ grid = grid.numpy()
+ grid = (grid * 255).astype(np.uint8)
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
+ k,
+ global_step,
+ current_epoch,
+ batch_idx)
+ path = os.path.join(root, filename)
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
+ Image.fromarray(grid).save(path)
+
+ def log_img(self, pl_module, batch, batch_idx, split="train"):
+ check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
+ if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
+ hasattr(pl_module, "log_images") and
+ callable(pl_module.log_images) and
+ self.max_images > 0):
+ logger = type(pl_module.logger)
+
+ is_train = pl_module.training
+ if is_train:
+ pl_module.eval()
+
+ with torch.no_grad():
+ images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
+
+ for k in images:
+ N = min(images[k].shape[0], self.max_images)
+ images[k] = images[k][:N]
+ if isinstance(images[k], torch.Tensor):
+ images[k] = images[k].detach().cpu()
+ if self.clamp:
+ images[k] = torch.clamp(images[k], -1., 1.)
+
+ self.log_local(pl_module.logger.save_dir, split, images,
+ pl_module.global_step, pl_module.current_epoch, batch_idx)
+
+ logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
+ logger_log_images(pl_module, images, pl_module.global_step, split)
+
+ if is_train:
+ pl_module.train()
+
+ def check_frequency(self, check_idx):
+ if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
+ check_idx > 0 or self.log_first_step):
+ try:
+ self.log_steps.pop(0)
+ except IndexError as e:
+ print(e)
+ return True
+ return False
+
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+ if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
+ self.log_img(pl_module, batch, batch_idx, split="train")
+
+ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+ if not self.disabled and pl_module.global_step > 0:
+ self.log_img(pl_module, batch, batch_idx, split="val")
+ if hasattr(pl_module, 'calibrate_grad_norm'):
+ if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
+ self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
+
+
+class CUDACallback(Callback):
+ # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
+ def on_train_epoch_start(self, trainer, pl_module):
+ # Reset the memory use counter
+ torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
+ torch.cuda.synchronize(trainer.root_gpu)
+ self.start_time = time.time()
+
+ def on_train_epoch_end(self, trainer, pl_module, outputs):
+ torch.cuda.synchronize(trainer.root_gpu)
+ max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20
+ epoch_time = time.time() - self.start_time
+
+ try:
+ max_memory = trainer.training_type_plugin.reduce(max_memory)
+ epoch_time = trainer.training_type_plugin.reduce(epoch_time)
+
+ rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
+ rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
+ except AttributeError:
+ pass
+
+
+if __name__ == "__main__":
+ from collections import OrderedDict
+ # custom parser to specify config files, train, test and debug mode,
+ # postfix, resume.
+ # `--key value` arguments are interpreted as arguments to the trainer.
+ # `nested.key=value` arguments are interpreted as config parameters.
+ # configs are merged from left-to-right followed by command line parameters.
+
+ # model:
+ # base_learning_rate: float
+ # target: path to lightning module
+ # params:
+ # key: value
+ # data:
+ # target: main.DataModuleFromConfig
+ # params:
+ # batch_size: int
+ # wrap: bool
+ # train:
+ # target: path to train dataset
+ # params:
+ # key: value
+ # validation:
+ # target: path to validation dataset
+ # params:
+ # key: value
+ # test:
+ # target: path to test dataset
+ # params:
+ # key: value
+ # lightning: (optional, has sane defaults and can be specified on cmdline)
+ # trainer:
+ # additional arguments to trainer
+ # logger:
+ # logger to instantiate
+ # modelcheckpoint:
+ # modelcheckpoint to instantiate
+ # callbacks:
+ # callback1:
+ # target: importpath
+ # params:
+ # key: value
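+ # A minimal illustrative config matching the schema above (targets and values
+ # are placeholders, not taken from this repo's configs):
+ #
+ # model:
+ #   base_learning_rate: 1.0e-4
+ #   target: ldm.models.diffusion.ddpm.LatentDiffusion
+ #   params: {...}
+ # data:
+ #   target: main.DataModuleFromConfig
+ #   params:
+ #     batch_size: 4
+ #     wrap: false
+ #     train:
+ #       target: path.to.TrainDataset
+ #       params: {...}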
+
+ now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+
+ # add cwd for convenience and to make classes in this file available when
+ # running as `python main.py`
+ # (in particular `main.DataModuleFromConfig`)
+ sys.path.append(os.getcwd())
+
+ parser = get_parser()
+ parser = Trainer.add_argparse_args(parser)
+
+ opt, unknown = parser.parse_known_args()
+ if opt.name and opt.resume:
+ raise ValueError(
+ "-n/--name and -r/--resume cannot be specified both."
+ "If you want to resume training in a new log folder, "
+ "use -n/--name in combination with --resume_from_checkpoint"
+ )
+ if opt.resume:
+ if not os.path.exists(opt.resume):
+ raise ValueError("Cannot find {}".format(opt.resume))
+ if os.path.isfile(opt.resume):
+ paths = opt.resume.split("/")
+ # idx = len(paths)-paths[::-1].index("logs")+1
+ # logdir = "/".join(paths[:idx])
+ logdir = "/".join(paths[:-2])
+ ckpt = opt.resume
+ else:
+ assert os.path.isdir(opt.resume), opt.resume
+ logdir = opt.resume.rstrip("/")
+ ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
+
+ # delete JPEGer
+ # state_dict = torch.load(ckpt)
+ # new_state_dict = OrderedDict()
+ # for k, v in state_dict['state_dict'].items():
+ # if 'jpeger' not in k or 'usm_sharpener' not in k:
+ # new_state_dict[k] = v
+ # if new_state_dict != state_dict['state_dict']:
+ # state_dict['state_dict'] = new_state_dict
+ # torch.save(state_dict, ckpt)
+ # del new_state_dict
+ # del state_dict
+
+ opt.resume_from_checkpoint = ckpt
+ base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
+ opt.base = base_configs + opt.base
+ _tmp = logdir.split("/")
+ nowname = _tmp[-1]
+ else:
+ if opt.name:
+ name = "_" + opt.name
+ elif opt.base:
+ cfg_fname = os.path.split(opt.base[0])[-1]
+ cfg_name = os.path.splitext(cfg_fname)[0]
+ name = "_" + cfg_name
+ else:
+ name = ""
+ nowname = now + name + opt.postfix
+ logdir = os.path.join(opt.logdir, nowname)
+
+ ckptdir = os.path.join(logdir, "checkpoints")
+ cfgdir = os.path.join(logdir, "configs")
+ seed_everything(opt.seed)
+
+ # try:
+ # init and save configs
+ configs = [OmegaConf.load(cfg) for cfg in opt.base]
+ cli = OmegaConf.from_dotlist(unknown)
+ config = OmegaConf.merge(*configs, cli)
+ lightning_config = config.pop("lightning", OmegaConf.create())
+ # merge trainer cli with config
+ trainer_config = lightning_config.get("trainer", OmegaConf.create())
+ # default to ddp
+ trainer_config["accelerator"] = "ddp"
+ for k in nondefault_trainer_args(opt):
+ trainer_config[k] = getattr(opt, k)
+ if not "gpus" in trainer_config:
+ del trainer_config["accelerator"]
+ cpu = True
+ else:
+ gpuinfo = trainer_config["gpus"]
+ print(f"Running on GPUs {gpuinfo}")
+ cpu = False
+ trainer_opt = argparse.Namespace(**trainer_config)
+ lightning_config.trainer = trainer_config
+
+ # model
+ model = instantiate_from_config(config.model)
+
+ model.configs = config
+
+ # trainer and callbacks
+ trainer_kwargs = dict()
+
+ # default logger configs
+ default_logger_cfgs = {
+ "wandb": {
+ "target": "pytorch_lightning.loggers.WandbLogger",
+ "params": {
+ "name": nowname,
+ "save_dir": logdir,
+ "offline": opt.debug,
+ "id": nowname,
+ }
+ },
+ "testtube": {
+ "target": "pytorch_lightning.loggers.TestTubeLogger",
+ "params": {
+ "name": "testtube",
+ "save_dir": logdir,
+ }
+ },
+ }
+ # We use wandb by default. Change to testtube if you do not want to use wandb
+ default_logger_cfg = default_logger_cfgs["wandb"]
+ os.makedirs(os.path.join(logdir, 'wandb'), exist_ok=True)
+ if "logger" in lightning_config:
+ logger_cfg = lightning_config.logger
+ else:
+ logger_cfg = OmegaConf.create()
+ logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
+ trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
+
+ # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
+ # specify which metric is used to determine best models
+ default_modelckpt_cfg = {
+ "target": "pytorch_lightning.callbacks.ModelCheckpoint",
+ "params": {
+ "dirpath": ckptdir,
+ "filename": "{epoch:06}",
+ "verbose": True,
+ "save_last": True,
+ }
+ }
+ if hasattr(model, "monitor"):
+ print(f"Monitoring {model.monitor} as checkpoint metric.")
+ default_modelckpt_cfg["params"]["monitor"] = model.monitor
+ default_modelckpt_cfg["params"]["save_top_k"] = 20
+
+ if "modelcheckpoint" in lightning_config:
+ modelckpt_cfg = lightning_config.modelcheckpoint
+ else:
+ modelckpt_cfg = OmegaConf.create()
+ modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
+ print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
+ if version.parse(pl.__version__) < version.parse('1.4.0'):
+ trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
+
+ # add callback which sets up log directory
+ default_callbacks_cfg = {
+ "setup_callback": {
+ "target": "main.SetupCallback",
+ "params": {
+ "resume": opt.resume,
+ "now": now,
+ "logdir": logdir,
+ "ckptdir": ckptdir,
+ "cfgdir": cfgdir,
+ "config": config,
+ "lightning_config": lightning_config,
+ }
+ },
+ "image_logger": {
+ "target": "main.ImageLogger",
+ "params": {
+ "batch_frequency": 750,
+ "max_images": 4,
+ "clamp": True
+ }
+ },
+ "learning_rate_logger": {
+ "target": "main.LearningRateMonitor",
+ "params": {
+ "logging_interval": "step",
+ # "log_momentum": True
+ }
+ },
+ "cuda_callback": {
+ "target": "main.CUDACallback"
+ },
+ }
+ if version.parse(pl.__version__) >= version.parse('1.4.0'):
+ default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})
+
+ if "callbacks" in lightning_config:
+ callbacks_cfg = lightning_config.callbacks
+ else:
+ callbacks_cfg = OmegaConf.create()
+
+ if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
+ print(
+ 'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
+ default_metrics_over_trainsteps_ckpt_dict = {
+ 'metrics_over_trainsteps_checkpoint':
+ {"target": 'pytorch_lightning.callbacks.ModelCheckpoint',
+ 'params': {
+ "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
+ "filename": "{epoch:06}-{step:09}",
+ "verbose": True,
+ 'save_top_k': -1,
+ 'every_n_train_steps': 10000,
+ 'save_weights_only': True
+ }
+ }
+ }
+ default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
+
+ callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
+ if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'):
+ callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint
+ elif 'ignore_keys_callback' in callbacks_cfg:
+ del callbacks_cfg['ignore_keys_callback']
+
+ trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
+
+ trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
+ trainer.logdir = logdir ###
+
+ # data
+ data = instantiate_from_config(config.data)
+ # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
+ # calling these ourselves should not be necessary but it is.
+ # lightning still takes care of proper multiprocessing though
+ data.prepare_data()
+ data.setup()
+ print("#### Data #####")
+ for k in data.datasets:
+ print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
+
+ # configure learning rate
+ bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
+ if not cpu:
+ ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
+ else:
+ ngpu = 1
+ if 'accumulate_grad_batches' in lightning_config.trainer:
+ accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
+ else:
+ accumulate_grad_batches = 1
+ print(f"accumulate_grad_batches = {accumulate_grad_batches}")
+ lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
+ if opt.scale_lr:
+ model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
+ print(
+ "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
+ model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
+ else:
+ model.learning_rate = base_lr
+ print("++++ NOT USING LR SCALING ++++")
+ print(f"Setting learning rate to {model.learning_rate:.2e}")
+
+
+ # allow checkpointing via USR1
+ def melk(*args, **kwargs):
+ # run all checkpoint hooks
+ if trainer.global_rank == 0:
+ print("Summoning checkpoint.")
+ ckpt_path = os.path.join(ckptdir, "last.ckpt")
+ trainer.save_checkpoint(ckpt_path)
+
+
+ def divein(*args, **kwargs):
+ if trainer.global_rank == 0:
+ import pudb;
+ pudb.set_trace()
+
+
+ import signal
+
+ signal.signal(signal.SIGUSR1, melk)
+ signal.signal(signal.SIGUSR2, divein)
+
+ # run
+ if opt.train:
+ try:
+ trainer.fit(model, data)
+ except Exception:
+ melk()
+ raise
+ if not opt.no_test and not trainer.interrupted:
+ trainer.test(model, data)
diff --git a/StableSR/predict.py b/StableSR/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4377a85e6931123225cca1eaf88a921d38b352f
--- /dev/null
+++ b/StableSR/predict.py
@@ -0,0 +1,280 @@
+# Prediction interface for Cog ⚙️
+# https://github.com/replicate/cog/blob/main/docs/python.md
+
+import os
+import PIL
+import numpy as np
+import copy
+import torch
+from omegaconf import OmegaConf
+from PIL import Image
+from tqdm import trange
+from itertools import islice
+from einops import rearrange, repeat
+from torch import autocast
+from pytorch_lightning import seed_everything
+import torch.nn.functional as F
+
+from ldm.util import instantiate_from_config
+from scripts.wavelet_color_fix import (
+ wavelet_reconstruction,
+ adaptive_instance_normalization,
+)
+
+from cog import BasePredictor, Input, Path
+
+
+class Predictor(BasePredictor):
+ def setup(self) -> None:
+ """Load the model into memory to make running multiple predictions efficient"""
+ config = OmegaConf.load("configs/stableSRNew/v2-finetune_text_T_512.yaml")
+ self.model = load_model_from_config(config, "stablesr_000117.ckpt")
+ device = torch.device("cuda")
+
+ self.model.configs = config
+ self.model = self.model.to(device)
+
+ vqgan_config = OmegaConf.load(
+ "configs/autoencoder/autoencoder_kl_64x64x4_resi.yaml"
+ )
+ self.vq_model = load_model_from_config(vqgan_config, "vqgan_cfw_00011.ckpt")
+ self.vq_model = self.vq_model.to(device)
+
+ def predict(
+ self,
+ input_image: Path = Input(description="Input image"),
+ ddpm_steps: int = Input(
+ description="Number of DDPM steps for sampling", default=200
+ ),
+ fidelity_weight: float = Input(
+ description="Balance the quality (lower number) and fidelity (higher number)",
+ default=0.5,
+ ),
+ upscale: float = Input(
+ description="The upscale for super-resolution, 4x SR by default",
+ default=4.0,
+ ),
+ tile_overlap: int = Input(
+ description="The overlap between tiles, betwwen 0 to 64",
+ ge=0,
+ le=64,
+ default=32,
+ ),
+ colorfix_type: str = Input(
+ choices=["adain", "wavelet", "none"], default="adain"
+ ),
+ seed: int = Input(
+ description="Random seed. Leave blank to randomize the seed", default=None
+ ),
+ ) -> Path:
+ """Run a single prediction on the model"""
+ if seed is None:
+ seed = int.from_bytes(os.urandom(2), "big")
+ print(f"Using seed: {seed}")
+
+ self.vq_model.decoder.fusion_w = fidelity_weight
+
+ seed_everything(seed)
+
+ n_samples = 1
+ device = torch.device("cuda")
+
+ cur_image = load_img(str(input_image)).to(device)
+ cur_image = F.interpolate(
+ cur_image,
+ size=(int(cur_image.size(-2) * upscale), int(cur_image.size(-1) * upscale)),
+ mode="bicubic",
+ )
+
+ self.model.register_schedule(
+ given_betas=None,
+ beta_schedule="linear",
+ timesteps=1000,
+ linear_start=0.00085,
+ linear_end=0.0120,
+ cosine_s=8e-3,
+ )
+ self.model.num_timesteps = 1000
+
+ sqrt_alphas_cumprod = copy.deepcopy(self.model.sqrt_alphas_cumprod)
+ sqrt_one_minus_alphas_cumprod = copy.deepcopy(
+ self.model.sqrt_one_minus_alphas_cumprod
+ )
+
+ use_timesteps = set(space_timesteps(1000, [ddpm_steps]))
+ last_alpha_cumprod = 1.0
+ new_betas = []
+ timestep_map = []
+ for i, alpha_cumprod in enumerate(self.model.alphas_cumprod):
+ if i in use_timesteps:
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+ last_alpha_cumprod = alpha_cumprod
+ timestep_map.append(i)
+ new_betas = [beta.data.cpu().numpy() for beta in new_betas]
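+ # The loop above re-spaces the diffusion schedule: for every kept timestep i,
+ # the new beta is 1 - alphas_cumprod[i] / alphas_cumprod[previous kept step],
+ # so the cumulative products of the shortened schedule match the original ones.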
+ self.model.register_schedule(
+ given_betas=np.array(new_betas), timesteps=len(new_betas)
+ )
+ self.model.num_timesteps = 1000
+ self.model.ori_timesteps = list(use_timesteps)
+ self.model.ori_timesteps.sort()
+ self.model = self.model.to(device)
+
+ precision_scope = autocast
+ input_size = 512
+
+ output = "/tmp/out.png"
+
+ with torch.no_grad():
+ with precision_scope("cuda"):
+ with self.model.ema_scope():
+ init_image = cur_image
+ init_image = init_image.clamp(-1.0, 1.0)
+ ori_size = None
+
+ print(init_image.size())
+
+ if (
+ init_image.size(-1) < input_size
+ or init_image.size(-2) < input_size
+ ):
+ ori_size = init_image.size()
+ new_h = max(ori_size[-2], input_size)
+ new_w = max(ori_size[-1], input_size)
+ init_template = torch.zeros(
+ 1, init_image.size(1), new_h, new_w
+ ).to(init_image.device)
+ init_template[:, :, : ori_size[-2], : ori_size[-1]] = init_image
+ else:
+ init_template = init_image
+
+ init_latent = self.model.get_first_stage_encoding(
+ self.model.encode_first_stage(init_template)
+ ) # move to latent space
+ text_init = [""] * n_samples
+ semantic_c = self.model.cond_stage_model(text_init)
+
+ noise = torch.randn_like(init_latent)
+ # If you would like to start from an intermediate step, add noise to the LR latent up to that specific step.
+ t = repeat(torch.tensor([999]), "1 -> b", b=init_image.size(0))
+ t = t.to(device).long()
+ x_T = self.model.q_sample_respace(
+ x_start=init_latent,
+ t=t,
+ sqrt_alphas_cumprod=sqrt_alphas_cumprod,
+ sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod,
+ noise=noise,
+ )
+ samples, _ = self.model.sample_canvas(
+ cond=semantic_c,
+ struct_cond=init_latent,
+ batch_size=init_image.size(0),
+ timesteps=ddpm_steps,
+ time_replace=ddpm_steps,
+ x_T=x_T,
+ return_intermediates=True,
+ tile_size=int(input_size / 8),
+ tile_overlap=tile_overlap,
+ batch_size_sample=n_samples,
+ )
+ _, enc_fea_lq = self.vq_model.encode(init_template)
+ x_samples = self.vq_model.decode(
+ samples * 1.0 / self.model.scale_factor, enc_fea_lq
+ )
+ if ori_size is not None:
+ x_samples = x_samples[:, :, : ori_size[-2], : ori_size[-1]]
+ if colorfix_type == "adain":
+ x_samples = adaptive_instance_normalization(
+ x_samples, init_image
+ )
+ elif colorfix_type == "wavelet":
+ x_samples = wavelet_reconstruction(x_samples, init_image)
+ x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
+
+ for i in range(init_image.size(0)):
+ x_sample = 255.0 * rearrange(
+ x_samples[i].cpu().numpy(), "c h w -> h w c"
+ )
+ Image.fromarray(x_sample.astype(np.uint8)).save(output)
+
+ return Path(output)
+
+
+def load_model_from_config(config, ckpt, verbose=False):
+ print(f"Loading model from {ckpt}")
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ model = instantiate_from_config(config.model)
+ m, u = model.load_state_dict(sd, strict=False)
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ print(u)
+
+ model.cuda()
+ model.eval()
+ return model
+
+
+def read_image(im_path):
+ im = np.array(Image.open(im_path).convert("RGB"))
+ im = im.astype(np.float32) / 255.0
+ im = im[None].transpose(0, 3, 1, 2)
+ im = (torch.from_numpy(im) - 0.5) / 0.5
+
+ return im.cuda()
+
+
+def space_timesteps(num_timesteps, section_counts):
+ if isinstance(section_counts, str):
+ if section_counts.startswith("ddim"):
+ desired_count = int(section_counts[len("ddim") :])
+ for i in range(1, num_timesteps):
+ if len(range(0, num_timesteps, i)) == desired_count:
+ return set(range(0, num_timesteps, i))
+ raise ValueError(
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
+ )
+ section_counts = [int(x) for x in section_counts.split(",")] # [250,]
+ size_per = num_timesteps // len(section_counts)
+ extra = num_timesteps % len(section_counts)
+ start_idx = 0
+ all_steps = []
+ for i, section_count in enumerate(section_counts):
+ size = size_per + (1 if i < extra else 0)
+ if size < section_count:
+ raise ValueError(
+ f"cannot divide section of {size} steps into {section_count}"
+ )
+ if section_count <= 1:
+ frac_stride = 1
+ else:
+ frac_stride = (size - 1) / (section_count - 1)
+ cur_idx = 0.0
+ taken_steps = []
+ for _ in range(section_count):
+ taken_steps.append(start_idx + round(cur_idx))
+ cur_idx += frac_stride
+ all_steps += taken_steps
+ start_idx += size
+ return set(all_steps)
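+# Illustrative note (not in the original file): space_timesteps(1000, [200]) returns
+# a set of 200 timestep indices spread evenly over 0..999, while the string form
+# space_timesteps(1000, "ddim50") uses the fixed DDIM striding instead.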
+
+
+def chunk(it, size):
+ it = iter(it)
+ return iter(lambda: tuple(islice(it, size)), ())
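+# e.g. list(chunk(range(5), 2)) == [(0, 1), (2, 3), (4,)]; the final tuple may be
+# shorter than `size` (illustrative note, not in the original file).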
+
+
+def load_img(path):
+ image = Image.open(path).convert("RGB")
+ w, h = image.size
+ print(f"loaded input image of size ({w}, {h}) from {path}")
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
diff --git a/StableSR/scripts/sr_val_ddpm_text_T_vqganfin_old.py b/StableSR/scripts/sr_val_ddpm_text_T_vqganfin_old.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c184bf85cd3cf32c6619c7ed0b7649cfdf62b84
--- /dev/null
+++ b/StableSR/scripts/sr_val_ddpm_text_T_vqganfin_old.py
@@ -0,0 +1,318 @@
+"""make variations of input image"""
+
+import argparse, os, sys, glob
+import PIL
+import torch
+import numpy as np
+import torchvision
+from omegaconf import OmegaConf
+from PIL import Image
+from tqdm import tqdm, trange
+from itertools import islice
+from einops import rearrange, repeat
+from torchvision.utils import make_grid
+from torch import autocast
+from contextlib import nullcontext
+import time
+from pytorch_lightning import seed_everything
+
+from ldm.util import instantiate_from_config
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.plms import PLMSSampler
+import math
+import copy
+from scripts.wavelet_color_fix import wavelet_reconstruction, adaptive_instance_normalization
+
+def space_timesteps(num_timesteps, section_counts):
+ """
+ Create a list of timesteps to use from an original diffusion process,
+ given the number of timesteps we want to take from equally-sized portions
+ of the original process.
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
+ If the stride is a string starting with "ddim", then the fixed striding
+ from the DDIM paper is used, and only one section is allowed.
+ :param num_timesteps: the number of diffusion steps in the original
+ process to divide up.
+ :param section_counts: either a list of numbers, or a string containing
+ comma-separated numbers, indicating the step count
+ per section. As a special case, use "ddimN" where N
+ is a number of steps to use the striding from the
+ DDIM paper.
+ :return: a set of diffusion steps from the original process to use.
+ """
+ if isinstance(section_counts, str):
+ if section_counts.startswith("ddim"):
+ desired_count = int(section_counts[len("ddim"):])
+ for i in range(1, num_timesteps):
+ if len(range(0, num_timesteps, i)) == desired_count:
+ return set(range(0, num_timesteps, i))
+ raise ValueError(
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
+ )
+ section_counts = [int(x) for x in section_counts.split(",")] #[250,]
+ size_per = num_timesteps // len(section_counts)
+ extra = num_timesteps % len(section_counts)
+ start_idx = 0
+ all_steps = []
+ for i, section_count in enumerate(section_counts):
+ size = size_per + (1 if i < extra else 0)
+ if size < section_count:
+ raise ValueError(
+ f"cannot divide section of {size} steps into {section_count}"
+ )
+ if section_count <= 1:
+ frac_stride = 1
+ else:
+ frac_stride = (size - 1) / (section_count - 1)
+ cur_idx = 0.0
+ taken_steps = []
+ for _ in range(section_count):
+ taken_steps.append(start_idx + round(cur_idx))
+ cur_idx += frac_stride
+ all_steps += taken_steps
+ start_idx += size
+ return set(all_steps)
+
+def chunk(it, size):
+ it = iter(it)
+ return iter(lambda: tuple(islice(it, size)), ())
+
+def load_model_from_config(config, ckpt, verbose=False):
+ print(f"Loading model from {ckpt}")
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ model = instantiate_from_config(config.model)
+ m, u = model.load_state_dict(sd, strict=False)
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ print(u)
+
+ model.cuda()
+ model.eval()
+ return model
+
+def load_img(path):
+ image = Image.open(path).convert("RGB")
+ w, h = image.size
+ print(f"loaded input image of size ({w}, {h}) from {path}")
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.*image - 1.
+
+
+def main():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--init-img",
+ type=str,
+ nargs="?",
+ help="path to the input image",
+ default="inputs/user_upload",
+ )
+ parser.add_argument(
+ "--outdir",
+ type=str,
+ nargs="?",
+ help="dir to write results to",
+ default="outputs/user_upload",
+ )
+ parser.add_argument(
+ "--ddpm_steps",
+ type=int,
+ default=1000,
+ help="number of ddpm sampling steps",
+ )
+ parser.add_argument(
+ "--C",
+ type=int,
+ default=4,
+ help="latent channels",
+ )
+ parser.add_argument(
+ "--f",
+ type=int,
+ default=8,
+ help="downsampling factor, most often 8 or 16",
+ )
+ parser.add_argument(
+ "--n_samples",
+ type=int,
+ default=2,
+ help="how many samples to produce for each given prompt. A.k.a batch size",
+ )
+ parser.add_argument(
+ "--config",
+ type=str,
+ default="configs/stableSRNew/v2-finetune_text_T_512.yaml",
+ help="path to config which constructs model",
+ )
+ parser.add_argument(
+ "--ckpt",
+ type=str,
+ default="models/ldm/stable-diffusion-v1/model.ckpt",
+ help="path to checkpoint of model",
+ )
+ parser.add_argument(
+ "--vqgan_ckpt",
+ type=str,
+ default="models/ldm/stable-diffusion-v1/epoch=000011.ckpt",
+ help="path to checkpoint of VQGAN model",
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="the seed (for reproducible sampling)",
+ )
+ parser.add_argument(
+ "--precision",
+ type=str,
+ help="evaluate at this precision",
+ choices=["full", "autocast"],
+ default="autocast"
+ )
+ parser.add_argument(
+ "--input_size",
+ type=int,
+ default=512,
+ help="input size",
+ )
+ parser.add_argument(
+ "--dec_w",
+ type=float,
+ default=0.5,
+ help="weight for combining VQGAN and Diffusion",
+ )
+ parser.add_argument(
+ "--colorfix_type",
+ type=str,
+ default="nofix",
+ help="Color fix type to adjust the color of HR result according to LR input: adain (used in paper); wavelet; nofix",
+ )
+
+ opt = parser.parse_args()
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+ print('>>>>>>>>>>color correction>>>>>>>>>>>')
+ if opt.colorfix_type == 'adain':
+ print('Use adain color correction')
+ elif opt.colorfix_type == 'wavelet':
+ print('Use wavelet color correction')
+ else:
+ print('No color correction')
+ print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
+
+ vqgan_config = OmegaConf.load("configs/autoencoder/autoencoder_kl_64x64x4_resi.yaml")
+ vq_model = load_model_from_config(vqgan_config, opt.vqgan_ckpt)
+ vq_model = vq_model.to(device)
+ vq_model.decoder.fusion_w = opt.dec_w
+
+ seed_everything(opt.seed)
+
+ transform = torchvision.transforms.Compose([
+ torchvision.transforms.Resize(opt.input_size),
+ torchvision.transforms.CenterCrop(opt.input_size),
+ ])
+
+ config = OmegaConf.load(f"{opt.config}")
+ model = load_model_from_config(config, f"{opt.ckpt}")
+ model = model.to(device)
+
+ os.makedirs(opt.outdir, exist_ok=True)
+ outpath = opt.outdir
+
+ batch_size = opt.n_samples
+
+ img_list_ori = os.listdir(opt.init_img)
+ img_list = copy.deepcopy(img_list_ori)
+ init_image_list = []
+ for item in img_list_ori:
+ if os.path.exists(os.path.join(outpath, item)):
+ img_list.remove(item)
+ continue
+ cur_image = load_img(os.path.join(opt.init_img, item)).to(device)
+ cur_image = transform(cur_image)
+ cur_image = cur_image.clamp(-1, 1)
+ init_image_list.append(cur_image)
+ init_image_list = torch.cat(init_image_list, dim=0)
+ niters = math.ceil(init_image_list.size(0) / batch_size)
+ init_image_list = init_image_list.chunk(niters)
+
+ model.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=0.00085, linear_end=0.0120, cosine_s=8e-3)
+ model.num_timesteps = 1000
+
+ sqrt_alphas_cumprod = copy.deepcopy(model.sqrt_alphas_cumprod)
+ sqrt_one_minus_alphas_cumprod = copy.deepcopy(model.sqrt_one_minus_alphas_cumprod)
+
+ use_timesteps = set(space_timesteps(1000, [opt.ddpm_steps]))
+ last_alpha_cumprod = 1.0
+ new_betas = []
+ timestep_map = []
+ for i, alpha_cumprod in enumerate(model.alphas_cumprod):
+ if i in use_timesteps:
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+ last_alpha_cumprod = alpha_cumprod
+ timestep_map.append(i)
+ new_betas = [beta.data.cpu().numpy() for beta in new_betas]
+ model.register_schedule(given_betas=np.array(new_betas), timesteps=len(new_betas))
+ model.num_timesteps = 1000
+ model.ori_timesteps = list(use_timesteps)
+ model.ori_timesteps.sort()
+ model = model.to(device)
+
+ precision_scope = autocast if opt.precision == "autocast" else nullcontext
+ niqe_list = []
+ with torch.no_grad():
+ with precision_scope("cuda"):
+ with model.ema_scope():
+ tic = time.time()
+ all_samples = list()
+ for n in trange(niters, desc="Sampling"):
+ init_image = init_image_list[n]
+ init_latent_generator, enc_fea_lq = vq_model.encode(init_image)
+ init_latent = model.get_first_stage_encoding(init_latent_generator)
+ text_init = ['']*init_image.size(0)
+ semantic_c = model.cond_stage_model(text_init)
+
+ noise = torch.randn_like(init_latent)
+ # If you would like to start from an intermediate step, add noise to the LR latent up to that specific step.
+ t = repeat(torch.tensor([999]), '1 -> b', b=init_image.size(0))
+ t = t.to(device).long()
+ x_T = model.q_sample_respace(x_start=init_latent, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise)
+ x_T = None
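+ # Note (added): x_T is overwritten with None here, so model.sample presumably
+ # starts from pure Gaussian noise instead of the noised LR latent computed above.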
+
+ samples, _ = model.sample(cond=semantic_c, struct_cond=init_latent, batch_size=init_image.size(0), timesteps=opt.ddpm_steps, time_replace=opt.ddpm_steps, x_T=x_T, return_intermediates=True)
+ x_samples = vq_model.decode(samples * 1. / model.scale_factor, enc_fea_lq)
+ if opt.colorfix_type == 'adain':
+ x_samples = adaptive_instance_normalization(x_samples, init_image)
+ elif opt.colorfix_type == 'wavelet':
+ x_samples = wavelet_reconstruction(x_samples, init_image)
+ x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
+
+ for i in range(init_image.size(0)):
+ img_name = img_list.pop(0)
+ basename = os.path.splitext(os.path.basename(img_name))[0]
+ x_sample = 255. * rearrange(x_samples[i].cpu().numpy(), 'c h w -> h w c')
+ Image.fromarray(x_sample.astype(np.uint8)).save(
+ os.path.join(outpath, basename+'.png'))
+
+ toc = time.time()
+
+ print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
+ f" \nEnjoy.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/StableSR/scripts/sr_val_ddpm_text_T_vqganfin_oldcanvas.py b/StableSR/scripts/sr_val_ddpm_text_T_vqganfin_oldcanvas.py
new file mode 100644
index 0000000000000000000000000000000000000000..6429a97c2d82c03d93985ac2de970dc7360da03a
--- /dev/null
+++ b/StableSR/scripts/sr_val_ddpm_text_T_vqganfin_oldcanvas.py
@@ -0,0 +1,351 @@
+"""make variations of input image"""
+
+import argparse, os, sys, glob
+import PIL
+import torch
+import numpy as np
+import torchvision
+from omegaconf import OmegaConf
+from PIL import Image
+from tqdm import tqdm, trange
+from itertools import islice
+from einops import rearrange, repeat
+from torchvision.utils import make_grid
+from torch import autocast
+from contextlib import nullcontext
+import time
+from pytorch_lightning import seed_everything
+
+from ldm.util import instantiate_from_config
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.plms import PLMSSampler
+import math
+import copy
+import torch.nn.functional as F
+import cv2
+from scripts.wavelet_color_fix import wavelet_reconstruction, adaptive_instance_normalization
+
+def space_timesteps(num_timesteps, section_counts):
+ """
+ Create a list of timesteps to use from an original diffusion process,
+ given the number of timesteps we want to take from equally-sized portions
+ of the original process.
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
+ If the stride is a string starting with "ddim", then the fixed striding
+ from the DDIM paper is used, and only one section is allowed.
+ :param num_timesteps: the number of diffusion steps in the original
+ process to divide up.
+ :param section_counts: either a list of numbers, or a string containing
+ comma-separated numbers, indicating the step count
+ per section. As a special case, use "ddimN" where N
+ is a number of steps to use the striding from the
+ DDIM paper.
+ :return: a set of diffusion steps from the original process to use.
+ """
+ if isinstance(section_counts, str):
+ if section_counts.startswith("ddim"):
+ desired_count = int(section_counts[len("ddim"):])
+ for i in range(1, num_timesteps):
+ if len(range(0, num_timesteps, i)) == desired_count:
+ return set(range(0, num_timesteps, i))
+ raise ValueError(
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
+ )
+ section_counts = [int(x) for x in section_counts.split(",")] #[250,]
+ size_per = num_timesteps // len(section_counts)
+ extra = num_timesteps % len(section_counts)
+ start_idx = 0
+ all_steps = []
+ for i, section_count in enumerate(section_counts):
+ size = size_per + (1 if i < extra else 0)
+ if size < section_count:
+ raise ValueError(
+ f"cannot divide section of {size} steps into {section_count}"
+ )
+ if section_count <= 1:
+ frac_stride = 1
+ else:
+ frac_stride = (size - 1) / (section_count - 1)
+ cur_idx = 0.0
+ taken_steps = []
+ for _ in range(section_count):
+ taken_steps.append(start_idx + round(cur_idx))
+ cur_idx += frac_stride
+ all_steps += taken_steps
+ start_idx += size
+ return set(all_steps)
+
+def chunk(it, size):
+ it = iter(it)
+ return iter(lambda: tuple(islice(it, size)), ())
+
+
+def load_model_from_config(config, ckpt, verbose=False):
+ print(f"Loading model from {ckpt}")
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ model = instantiate_from_config(config.model)
+ m, u = model.load_state_dict(sd, strict=False)
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ print(u)
+
+ model.cuda()
+ model.eval()
+ return model
+
+def load_img(path):
+ image = Image.open(path).convert("RGB")
+ w, h = image.size
+ print(f"loaded input image of size ({w}, {h}) from {path}")
+    w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8
+ image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.*image - 1.
+
+
+def main():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--init-img",
+ type=str,
+ nargs="?",
+ help="path to the input image",
+ default="inputs/user_upload"
+ )
+ parser.add_argument(
+ "--outdir",
+ type=str,
+ nargs="?",
+ help="dir to write results to",
+ default="outputs/user_upload"
+ )
+ parser.add_argument(
+ "--ddpm_steps",
+ type=int,
+ default=1000,
+ help="number of ddpm sampling steps",
+ )
+ parser.add_argument(
+ "--C",
+ type=int,
+ default=4,
+ help="latent channels",
+ )
+ parser.add_argument(
+ "--f",
+ type=int,
+ default=8,
+ help="downsampling factor, most often 8 or 16",
+ )
+ parser.add_argument(
+ "--n_samples",
+ type=int,
+ default=2,
+ help="how many samples to produce for each given prompt. A.k.a batch size",
+ )
+ parser.add_argument(
+ "--config",
+ type=str,
+ default="configs/stableSRNew/v2-finetune_text_T_512.yaml",
+ help="path to config which constructs model",
+ )
+ parser.add_argument(
+ "--ckpt",
+ type=str,
+ default="./stablesr_000117.ckpt",
+ help="path to checkpoint of model",
+ )
+ parser.add_argument(
+ "--vqgan_ckpt",
+ type=str,
+ default="./vqgan_cfw_00011.ckpt",
+ help="path to checkpoint of VQGAN model",
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="the seed (for reproducible sampling)",
+ )
+ parser.add_argument(
+ "--precision",
+ type=str,
+ help="evaluate at this precision",
+ choices=["full", "autocast"],
+ default="autocast"
+ )
+ parser.add_argument(
+ "--input_size",
+ type=int,
+ default=512,
+ help="input size",
+ )
+ parser.add_argument(
+ "--dec_w",
+ type=float,
+ default=0.5,
+ help="weight for combining VQGAN and Diffusion",
+ )
+ parser.add_argument(
+ "--tile_overlap",
+ type=int,
+ default=32,
+ help="tile overlap size",
+ )
+ parser.add_argument(
+ "--upscale",
+ type=float,
+ default=4.0,
+ help="upsample scale",
+ )
+ parser.add_argument(
+ "--colorfix_type",
+ type=str,
+ default="nofix",
+ help="Color fix type to adjust the color of HR result according to LR input: adain (used in paper); wavelet; nofix",
+ )
+
+ opt = parser.parse_args()
+ seed_everything(opt.seed)
+
+ print('>>>>>>>>>>color correction>>>>>>>>>>>')
+ if opt.colorfix_type == 'adain':
+ print('Use adain color correction')
+ elif opt.colorfix_type == 'wavelet':
+ print('Use wavelet color correction')
+ else:
+ print('No color correction')
+ print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
+
+ config = OmegaConf.load(f"{opt.config}")
+ model = load_model_from_config(config, f"{opt.ckpt}")
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model = model.to(device)
+
+ model.configs = config
+
+ vqgan_config = OmegaConf.load("configs/autoencoder/autoencoder_kl_64x64x4_resi.yaml")
+ vq_model = load_model_from_config(vqgan_config, opt.vqgan_ckpt)
+ vq_model = vq_model.to(device)
+ vq_model.decoder.fusion_w = opt.dec_w
+
+ os.makedirs(opt.outdir, exist_ok=True)
+ outpath = opt.outdir
+
+ batch_size = opt.n_samples
+
+ img_list_ori = os.listdir(opt.init_img)
+ img_list = copy.deepcopy(img_list_ori)
+ init_image_list = []
+ for item in img_list_ori:
+ if os.path.exists(os.path.join(outpath, item)):
+ img_list.remove(item)
+ continue
+ cur_image = load_img(os.path.join(opt.init_img, item)).to(device)
+ # max size: 1800 x 1800 for V100
+ cur_image = F.interpolate(
+ cur_image,
+ size=(int(cur_image.size(-2)*opt.upscale),
+ int(cur_image.size(-1)*opt.upscale)),
+ mode='bicubic',
+ )
+ init_image_list.append(cur_image)
+
+ model.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=0.00085, linear_end=0.0120, cosine_s=8e-3)
+ model.num_timesteps = 1000
+
+ sqrt_alphas_cumprod = copy.deepcopy(model.sqrt_alphas_cumprod)
+ sqrt_one_minus_alphas_cumprod = copy.deepcopy(model.sqrt_one_minus_alphas_cumprod)
+
+ use_timesteps = set(space_timesteps(1000, [opt.ddpm_steps]))
+ last_alpha_cumprod = 1.0
+ new_betas = []
+ timestep_map = []
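+    # Respace the schedule: for each kept step t, the new beta is 1 - alphabar_t / alphabar_{t_prev},
+    # so the shortened chain reproduces the original cumulative alpha products at the retained steps.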
+ for i, alpha_cumprod in enumerate(model.alphas_cumprod):
+ if i in use_timesteps:
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+ last_alpha_cumprod = alpha_cumprod
+ timestep_map.append(i)
+ new_betas = [beta.data.cpu().numpy() for beta in new_betas]
+ model.register_schedule(given_betas=np.array(new_betas), timesteps=len(new_betas))
+ model.num_timesteps = 1000
+ model.ori_timesteps = list(use_timesteps)
+ model.ori_timesteps.sort()
+ model = model.to(device)
+
+ precision_scope = autocast if opt.precision == "autocast" else nullcontext
+ with torch.no_grad():
+ with precision_scope("cuda"):
+ with model.ema_scope():
+ tic = time.time()
+ all_samples = list()
+ for n in trange(len(init_image_list), desc="Sampling"):
+ init_image = init_image_list[n]
+ init_image = init_image.clamp(-1.0, 1.0)
+ ori_size = None
+
+ print('>>>>>>>>>>>>>>>>>>>>>>>')
+ print(init_image.size())
+
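+                    # Inputs smaller than --input_size are zero-padded onto an input_size canvas here;
+                    # the valid region is cropped back out after decoding (see the ori_size crop below).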
+ if init_image.size(-1) < opt.input_size or init_image.size(-2) < opt.input_size:
+ ori_size = init_image.size()
+ new_h = max(ori_size[-2], opt.input_size)
+ new_w = max(ori_size[-1], opt.input_size)
+ init_template = torch.zeros(1, init_image.size(1), new_h, new_w).to(init_image.device)
+ init_template[:, :, :ori_size[-2], :ori_size[-1]] = init_image
+ else:
+ init_template = init_image
+
+ init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_template)) # move to latent space
+ text_init = ['']*opt.n_samples
+ semantic_c = model.cond_stage_model(text_init)
+
+ noise = torch.randn_like(init_latent)
+                    # If you would like to start from an intermediate step, add noise to the LR latent up to that specific step.
+ t = repeat(torch.tensor([999]), '1 -> b', b=init_image.size(0))
+ t = t.to(device).long()
+ x_T = model.q_sample_respace(x_start=init_latent, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise)
+ # x_T = noise
+
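+                    # sample_canvas runs tiled diffusion in latent space: tile_size is given in latent
+                    # units (input_size // 8) and struct_cond feeds the encoded LR image in as structural guidance.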
+ samples, _ = model.sample_canvas(cond=semantic_c, struct_cond=init_latent, batch_size=init_image.size(0), timesteps=opt.ddpm_steps, time_replace=opt.ddpm_steps, x_T=x_T, return_intermediates=True, tile_size=int(opt.input_size/8), tile_overlap=opt.tile_overlap, batch_size_sample=opt.n_samples)
+ _, enc_fea_lq = vq_model.encode(init_template)
+ x_samples = vq_model.decode(samples * 1. / model.scale_factor, enc_fea_lq)
+ if ori_size is not None:
+ x_samples = x_samples[:, :, :ori_size[-2], :ori_size[-1]]
+ if opt.colorfix_type == 'adain':
+ x_samples = adaptive_instance_normalization(x_samples, init_image)
+ elif opt.colorfix_type == 'wavelet':
+ x_samples = wavelet_reconstruction(x_samples, init_image)
+ x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
+
+ for i in range(init_image.size(0)):
+ img_name = img_list.pop(0)
+ basename = os.path.splitext(os.path.basename(img_name))[0]
+ x_sample = 255. * rearrange(x_samples[i].cpu().numpy(), 'c h w -> h w c')
+ Image.fromarray(x_sample.astype(np.uint8)).save(
+ os.path.join(outpath, basename+'.png'))
+ init_image = torch.clamp((init_image + 1.0) / 2.0, min=0.0, max=1.0)
+ init_image = 255. * rearrange(init_image[i].cpu().numpy(), 'c h w -> h w c')
+ Image.fromarray(init_image.astype(np.uint8)).save(
+ os.path.join(outpath, basename+'_lq.png'))
+
+ toc = time.time()
+
+ print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
+ f" \nEnjoy.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/StableSR/scripts/sr_val_ddpm_text_T_vqganfin_oldcanvas_tile.py b/StableSR/scripts/sr_val_ddpm_text_T_vqganfin_oldcanvas_tile.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8d0671f9c059edb00a32773d6a5fe9deb1014d9
--- /dev/null
+++ b/StableSR/scripts/sr_val_ddpm_text_T_vqganfin_oldcanvas_tile.py
@@ -0,0 +1,422 @@
+"""make variations of input image"""
+
+import argparse, os, sys, glob
+import PIL
+import torch
+import numpy as np
+import torchvision
+from omegaconf import OmegaConf
+from PIL import Image
+from tqdm import tqdm, trange
+from itertools import islice
+from einops import rearrange, repeat
+from torchvision.utils import make_grid
+from torch import autocast
+from contextlib import nullcontext
+import time
+from pytorch_lightning import seed_everything
+
+from ldm.util import instantiate_from_config
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.plms import PLMSSampler
+import math
+import copy
+import torch.nn.functional as F
+import cv2
+from util_image import ImageSpliterTh
+from pathlib import Path
+from scripts.wavelet_color_fix import wavelet_reconstruction, adaptive_instance_normalization
+
+def space_timesteps(num_timesteps, section_counts):
+ """
+ Create a list of timesteps to use from an original diffusion process,
+ given the number of timesteps we want to take from equally-sized portions
+ of the original process.
+    For example, if there are 300 timesteps and the section counts are [10,15,20],
+    then the first 100 timesteps are strided down to 10 timesteps, the second 100
+    are strided down to 15 timesteps, and the final 100 are strided down to 20.
+ If the stride is a string starting with "ddim", then the fixed striding
+ from the DDIM paper is used, and only one section is allowed.
+ :param num_timesteps: the number of diffusion steps in the original
+ process to divide up.
+ :param section_counts: either a list of numbers, or a string containing
+ comma-separated numbers, indicating the step count
+ per section. As a special case, use "ddimN" where N
+ is a number of steps to use the striding from the
+ DDIM paper.
+ :return: a set of diffusion steps from the original process to use.
+ """
+ if isinstance(section_counts, str):
+ if section_counts.startswith("ddim"):
+ desired_count = int(section_counts[len("ddim"):])
+ for i in range(1, num_timesteps):
+ if len(range(0, num_timesteps, i)) == desired_count:
+ return set(range(0, num_timesteps, i))
+ raise ValueError(
+                f"cannot create exactly {desired_count} steps with an integer stride"
+ )
+ section_counts = [int(x) for x in section_counts.split(",")] #[250,]
+ size_per = num_timesteps // len(section_counts)
+ extra = num_timesteps % len(section_counts)
+ start_idx = 0
+ all_steps = []
+ for i, section_count in enumerate(section_counts):
+ size = size_per + (1 if i < extra else 0)
+ if size < section_count:
+ raise ValueError(
+ f"cannot divide section of {size} steps into {section_count}"
+ )
+ if section_count <= 1:
+ frac_stride = 1
+ else:
+ frac_stride = (size - 1) / (section_count - 1)
+ cur_idx = 0.0
+ taken_steps = []
+ for _ in range(section_count):
+ taken_steps.append(start_idx + round(cur_idx))
+ cur_idx += frac_stride
+ all_steps += taken_steps
+ start_idx += size
+ return set(all_steps)
+
+def chunk(it, size):
+ it = iter(it)
+ return iter(lambda: tuple(islice(it, size)), ())
+
+
+def load_model_from_config(config, ckpt, verbose=False):
+ print(f"Loading model from {ckpt}")
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ model = instantiate_from_config(config.model)
+ m, u = model.load_state_dict(sd, strict=False)
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ print(u)
+
+ model.cuda()
+ model.eval()
+ return model
+
+def load_img(path):
+ image = Image.open(path).convert("RGB")
+ w, h = image.size
+ print(f"loaded input image of size ({w}, {h}) from {path}")
+    w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8
+ image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.*image - 1.
+
+def read_image(im_path):
+ im = np.array(Image.open(im_path).convert("RGB"))
+ im = im.astype(np.float32)/255.0
+ im = im[None].transpose(0,3,1,2)
+ im = (torch.from_numpy(im) - 0.5) / 0.5
+
+ return im.cuda()
+
+def main():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--init-img",
+ type=str,
+ nargs="?",
+ help="path to the input image",
+ default="inputs/user_upload"
+ )
+ parser.add_argument(
+ "--outdir",
+ type=str,
+ nargs="?",
+ help="dir to write results to",
+ default="outputs/user_upload"
+ )
+ parser.add_argument(
+ "--ddpm_steps",
+ type=int,
+ default=1000,
+ help="number of ddpm sampling steps",
+ )
+ parser.add_argument(
+ "--n_iter",
+ type=int,
+ default=1,
+ help="sample this often",
+ )
+ parser.add_argument(
+ "--C",
+ type=int,
+ default=4,
+ help="latent channels",
+ )
+ parser.add_argument(
+ "--f",
+ type=int,
+ default=8,
+ help="downsampling factor, most often 8 or 16",
+ )
+ parser.add_argument(
+ "--n_samples",
+ type=int,
+ default=1,
+ help="how many samples to produce for each given prompt. A.k.a batch size",
+ )
+ parser.add_argument(
+ "--config",
+ type=str,
+ default="configs/stable-diffusion/v1-inference.yaml",
+ help="path to config which constructs model",
+ )
+ parser.add_argument(
+ "--ckpt",
+ type=str,
+ default="./stablesr_000117.ckpt",
+ help="path to checkpoint of model",
+ )
+ parser.add_argument(
+ "--vqgan_ckpt",
+ type=str,
+ default="./vqgan_cfw_00011.ckpt",
+ help="path to checkpoint of VQGAN model",
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="the seed (for reproducible sampling)",
+ )
+ parser.add_argument(
+ "--precision",
+ type=str,
+ help="evaluate at this precision",
+ choices=["full", "autocast"],
+ default="autocast"
+ )
+ parser.add_argument(
+ "--dec_w",
+ type=float,
+ default=0.5,
+ help="weight for combining VQGAN and Diffusion",
+ )
+ parser.add_argument(
+ "--tile_overlap",
+ type=int,
+ default=32,
+ help="tile overlap size (in latent)",
+ )
+ parser.add_argument(
+ "--upscale",
+ type=float,
+ default=4.0,
+ help="upsample scale",
+ )
+ parser.add_argument(
+ "--colorfix_type",
+ type=str,
+ default="nofix",
+ help="Color fix type to adjust the color of HR result according to LR input: adain (used in paper); wavelet; nofix",
+ )
+ parser.add_argument(
+ "--vqgantile_stride",
+ type=int,
+ default=1000,
+ help="the stride for tile operation before VQGAN decoder (in pixel)",
+ )
+ parser.add_argument(
+ "--vqgantile_size",
+ type=int,
+ default=1280,
+ help="the size for tile operation before VQGAN decoder (in pixel)",
+ )
+ parser.add_argument(
+ "--input_size",
+ type=int,
+ default=512,
+ help="input size",
+ )
+
+ opt = parser.parse_args()
+ seed_everything(opt.seed)
+
+ print('>>>>>>>>>>color correction>>>>>>>>>>>')
+ if opt.colorfix_type == 'adain':
+ print('Use adain color correction')
+ elif opt.colorfix_type == 'wavelet':
+ print('Use wavelet color correction')
+ else:
+ print('No color correction')
+ print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
+
+ config = OmegaConf.load(f"{opt.config}")
+ model = load_model_from_config(config, f"{opt.ckpt}")
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model = model.to(device)
+
+ model.configs = config
+
+ vqgan_config = OmegaConf.load("configs/autoencoder/autoencoder_kl_64x64x4_resi.yaml")
+ vq_model = load_model_from_config(vqgan_config, opt.vqgan_ckpt)
+ vq_model = vq_model.to(device)
+ vq_model.decoder.fusion_w = opt.dec_w
+
+ os.makedirs(opt.outdir, exist_ok=True)
+ outpath = opt.outdir
+
+ batch_size = opt.n_samples
+
+ images_path_ori = sorted(glob.glob(os.path.join(opt.init_img, "*")))
+ images_path = copy.deepcopy(images_path_ori)
+ for item in images_path_ori:
+ img_name = item.split('/')[-1]
+ if os.path.exists(os.path.join(outpath, img_name)):
+ images_path.remove(item)
+ print(f"Found {len(images_path)} inputs.")
+
+ model.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=0.00085, linear_end=0.0120, cosine_s=8e-3)
+ model.num_timesteps = 1000
+
+ sqrt_alphas_cumprod = copy.deepcopy(model.sqrt_alphas_cumprod)
+ sqrt_one_minus_alphas_cumprod = copy.deepcopy(model.sqrt_one_minus_alphas_cumprod)
+
+ use_timesteps = set(space_timesteps(1000, [opt.ddpm_steps]))
+ last_alpha_cumprod = 1.0
+ new_betas = []
+ timestep_map = []
+ for i, alpha_cumprod in enumerate(model.alphas_cumprod):
+ if i in use_timesteps:
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+ last_alpha_cumprod = alpha_cumprod
+ timestep_map.append(i)
+ new_betas = [beta.data.cpu().numpy() for beta in new_betas]
+ model.register_schedule(given_betas=np.array(new_betas), timesteps=len(new_betas))
+ model.num_timesteps = 1000
+ model.ori_timesteps = list(use_timesteps)
+ model.ori_timesteps.sort()
+ model = model.to(device)
+
+ precision_scope = autocast if opt.precision == "autocast" else nullcontext
+ niqe_list = []
+ with torch.no_grad():
+ with precision_scope("cuda"):
+ with model.ema_scope():
+ tic = time.time()
+ all_samples = list()
+ for n in trange(len(images_path), desc="Sampling"):
+ if (n + 1) % opt.n_samples == 1 or opt.n_samples == 1:
+ cur_image = read_image(images_path[n])
+ size_min = min(cur_image.size(-1), cur_image.size(-2))
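+                            # Upscale so the short side reaches at least --input_size; if this exceeds the
+                            # requested --upscale, the result is resized back down after sampling (see below).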
+ upsample_scale = max(opt.input_size/size_min, opt.upscale)
+ cur_image = F.interpolate(
+ cur_image,
+ size=(int(cur_image.size(-2)*upsample_scale),
+ int(cur_image.size(-1)*upsample_scale)),
+ mode='bicubic',
+ )
+ cur_image = cur_image.clamp(-1, 1)
+ im_lq_bs = [cur_image, ] # 1 x c x h x w, [-1, 1]
+ im_path_bs = [images_path[n], ]
+ else:
+ cur_image = read_image(images_path[n])
+ size_min = min(cur_image.size(-1), cur_image.size(-2))
+ upsample_scale = max(opt.input_size/size_min, opt.upscale)
+ cur_image = F.interpolate(
+ cur_image,
+ size=(int(cur_image.size(-2)*upsample_scale),
+ int(cur_image.size(-1)*upsample_scale)),
+ mode='bicubic',
+ )
+ cur_image = cur_image.clamp(-1, 1)
+ im_lq_bs.append(cur_image) # 1 x c x h x w, [-1, 1]
+ im_path_bs.append(images_path[n]) # 1 x c x h x w, [-1, 1]
+
+ if (n + 1) % opt.n_samples == 0 or (n+1) == len(images_path):
+ im_lq_bs = torch.cat(im_lq_bs, dim=0)
+ ori_h, ori_w = im_lq_bs.shape[2:]
+ ref_patch=None
+ if not (ori_h % 32 == 0 and ori_w % 32 == 0):
+ flag_pad = True
+ pad_h = ((ori_h // 32) + 1) * 32 - ori_h
+ pad_w = ((ori_w // 32) + 1) * 32 - ori_w
+ im_lq_bs = F.pad(im_lq_bs, pad=(0, pad_w, 0, pad_h), mode='reflect')
+ else:
+ flag_pad = False
+
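+                        # Images larger than --vqgantile_size are split into overlapping pixel-space tiles
+                        # before encoding/decoding; ImageSpliterTh re-assembles them by averaging the overlaps.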
+ if im_lq_bs.shape[2] > opt.vqgantile_size or im_lq_bs.shape[3] > opt.vqgantile_size:
+ im_spliter = ImageSpliterTh(im_lq_bs, opt.vqgantile_size, opt.vqgantile_stride, sf=1)
+ for im_lq_pch, index_infos in im_spliter:
+ seed_everything(opt.seed)
+ init_latent = model.get_first_stage_encoding(model.encode_first_stage(im_lq_pch)) # move to latent space
+ text_init = ['']*opt.n_samples
+ semantic_c = model.cond_stage_model(text_init)
+ noise = torch.randn_like(init_latent)
+                                # If you would like to start from an intermediate step, add noise to the LR latent up to that specific step.
+ t = repeat(torch.tensor([999]), '1 -> b', b=im_lq_bs.size(0))
+ t = t.to(device).long()
+ x_T = model.q_sample_respace(x_start=init_latent, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise)
+ # x_T = noise
+ samples, _ = model.sample_canvas(cond=semantic_c, struct_cond=init_latent, batch_size=im_lq_pch.size(0), timesteps=opt.ddpm_steps, time_replace=opt.ddpm_steps, x_T=x_T, return_intermediates=True, tile_size=int(opt.input_size/8), tile_overlap=opt.tile_overlap, batch_size_sample=opt.n_samples)
+ _, enc_fea_lq = vq_model.encode(im_lq_pch)
+ x_samples = vq_model.decode(samples * 1. / model.scale_factor, enc_fea_lq)
+ if opt.colorfix_type == 'adain':
+ x_samples = adaptive_instance_normalization(x_samples, im_lq_pch)
+ elif opt.colorfix_type == 'wavelet':
+ x_samples = wavelet_reconstruction(x_samples, im_lq_pch)
+ im_spliter.update(x_samples, index_infos)
+ im_sr = im_spliter.gather()
+ im_sr = torch.clamp((im_sr+1.0)/2.0, min=0.0, max=1.0)
+ else:
+ init_latent = model.get_first_stage_encoding(model.encode_first_stage(im_lq_bs)) # move to latent space
+ text_init = ['']*opt.n_samples
+ semantic_c = model.cond_stage_model(text_init)
+ noise = torch.randn_like(init_latent)
+                            # If you would like to start from an intermediate step, add noise to the LR latent up to that specific step.
+ t = repeat(torch.tensor([999]), '1 -> b', b=im_lq_bs.size(0))
+ t = t.to(device).long()
+ x_T = model.q_sample_respace(x_start=init_latent, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise)
+ # x_T = noise
+ samples, _ = model.sample_canvas(cond=semantic_c, struct_cond=init_latent, batch_size=im_lq_bs.size(0), timesteps=opt.ddpm_steps, time_replace=opt.ddpm_steps, x_T=x_T, return_intermediates=True, tile_size=int(opt.input_size/8), tile_overlap=opt.tile_overlap, batch_size_sample=opt.n_samples)
+ _, enc_fea_lq = vq_model.encode(im_lq_bs)
+ x_samples = vq_model.decode(samples * 1. / model.scale_factor, enc_fea_lq)
+ if opt.colorfix_type == 'adain':
+ x_samples = adaptive_instance_normalization(x_samples, im_lq_bs)
+ elif opt.colorfix_type == 'wavelet':
+ x_samples = wavelet_reconstruction(x_samples, im_lq_bs)
+ im_sr = torch.clamp((x_samples+1.0)/2.0, min=0.0, max=1.0)
+
+ if upsample_scale > opt.upscale:
+ im_sr = F.interpolate(
+ im_sr,
+ size=(int(im_lq_bs.size(-2)*opt.upscale/upsample_scale),
+ int(im_lq_bs.size(-1)*opt.upscale/upsample_scale)),
+ mode='bicubic',
+ )
+ im_sr = torch.clamp(im_sr, min=0.0, max=1.0)
+
+ im_sr = im_sr.cpu().numpy().transpose(0,2,3,1)*255 # b x h x w x c
+
+ if flag_pad:
+ im_sr = im_sr[:, :ori_h, :ori_w, ]
+
+ for jj in range(im_lq_bs.shape[0]):
+ img_name = str(Path(im_path_bs[jj]).name)
+ basename = os.path.splitext(os.path.basename(img_name))[0]
+ outpath = str(Path(opt.outdir)) + '/' + basename + '.png'
+ Image.fromarray(im_sr[jj, ].astype(np.uint8)).save(outpath)
+
+ toc = time.time()
+
+ print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
+ f" \nEnjoy.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/StableSR/scripts/util_image.py b/StableSR/scripts/util_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..812bbb859b5e93c49b23baa6d47aa8d6ae5c5a4a
--- /dev/null
+++ b/StableSR/scripts/util_image.py
@@ -0,0 +1,793 @@
+#!/usr/bin/env python
+# -*- coding:utf-8 -*-
+# Power by Zongsheng Yue 2021-11-24 16:54:19
+
+import sys
+import cv2
+import math
+import torch
+import random
+import numpy as np
+from scipy import fft
+from pathlib import Path
+from einops import rearrange
+from skimage import img_as_ubyte, img_as_float32
+from torchvision.utils import make_grid
+
+# --------------------------Metrics----------------------------
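+# Single-channel SSIM (Wang et al.) computed with an 11x11 Gaussian window (sigma=1.5) and the
+# standard constants C1=(0.01*255)^2, C2=(0.03*255)^2, evaluated on the valid (un-padded) region only.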
+def ssim(img1, img2):
+ C1 = (0.01 * 255)**2
+ C2 = (0.03 * 255)**2
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ kernel = cv2.getGaussianKernel(11, 1.5)
+ window = np.outer(kernel, kernel.transpose())
+
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+ mu1_sq = mu1**2
+ mu2_sq = mu2**2
+ mu1_mu2 = mu1 * mu2
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+ (sigma1_sq + sigma2_sq + C2))
+ return ssim_map.mean()
+
+def calculate_ssim(im1, im2, border=0, ycbcr=False):
+ '''
+    SSIM, producing the same output as MATLAB's implementation.
+    im1, im2: h x w x c, [0, 255], uint8
+ '''
+ if not im1.shape == im2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+
+ if ycbcr:
+ im1 = rgb2ycbcr(im1, True)
+ im2 = rgb2ycbcr(im2, True)
+
+ h, w = im1.shape[:2]
+ im1 = im1[border:h-border, border:w-border]
+ im2 = im2[border:h-border, border:w-border]
+
+ if im1.ndim == 2:
+ return ssim(im1, im2)
+ elif im1.ndim == 3:
+ if im1.shape[2] == 3:
+ ssims = []
+ for i in range(3):
+ ssims.append(ssim(im1[:,:,i], im2[:,:,i]))
+ return np.array(ssims).mean()
+ elif im1.shape[2] == 1:
+ return ssim(np.squeeze(im1), np.squeeze(im2))
+ else:
+ raise ValueError('Wrong input image dimensions.')
+
+def calculate_psnr(im1, im2, border=0, ycbcr=False):
+ '''
+ PSNR metric.
+    im1, im2: h x w x c, [0, 255], uint8
+ '''
+ if not im1.shape == im2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+
+ if ycbcr:
+ im1 = rgb2ycbcr(im1, True)
+ im2 = rgb2ycbcr(im2, True)
+
+ h, w = im1.shape[:2]
+ im1 = im1[border:h-border, border:w-border]
+ im2 = im2[border:h-border, border:w-border]
+
+ im1 = im1.astype(np.float64)
+ im2 = im2.astype(np.float64)
+ mse = np.mean((im1 - im2)**2)
+ if mse == 0:
+ return float('inf')
+ return 20 * math.log10(255.0 / math.sqrt(mse))
+
+def batch_PSNR(img, imclean, border=0, ycbcr=False):
+ if ycbcr:
+ img = rgb2ycbcrTorch(img, True)
+ imclean = rgb2ycbcrTorch(imclean, True)
+ Img = img.data.cpu().numpy()
+ Iclean = imclean.data.cpu().numpy()
+ Img = img_as_ubyte(Img)
+ Iclean = img_as_ubyte(Iclean)
+ PSNR = 0
+ h, w = Iclean.shape[2:]
+ for i in range(Img.shape[0]):
+ PSNR += calculate_psnr(Iclean[i,:,].transpose((1,2,0)), Img[i,:,].transpose((1,2,0)), border)
+ return PSNR
+
+def batch_SSIM(img, imclean, border=0, ycbcr=False):
+ if ycbcr:
+ img = rgb2ycbcrTorch(img, True)
+ imclean = rgb2ycbcrTorch(imclean, True)
+ Img = img.data.cpu().numpy()
+ Iclean = imclean.data.cpu().numpy()
+ Img = img_as_ubyte(Img)
+ Iclean = img_as_ubyte(Iclean)
+ SSIM = 0
+ for i in range(Img.shape[0]):
+ SSIM += calculate_ssim(Iclean[i,:,].transpose((1,2,0)), Img[i,:,].transpose((1,2,0)), border)
+ return SSIM
+
+def normalize_np(im, mean=0.5, std=0.5, reverse=False):
+ '''
+ Input:
+ im: h x w x c, numpy array
+ Normalize: (im - mean) / std
+ Reverse: im * std + mean
+
+ '''
+ if not isinstance(mean, (list, tuple)):
+ mean = [mean, ] * im.shape[2]
+ mean = np.array(mean).reshape([1, 1, im.shape[2]])
+
+ if not isinstance(std, (list, tuple)):
+ std = [std, ] * im.shape[2]
+ std = np.array(std).reshape([1, 1, im.shape[2]])
+
+ if not reverse:
+ out = (im.astype(np.float32) - mean) / std
+ else:
+ out = im.astype(np.float32) * std + mean
+ return out
+
+def normalize_th(im, mean=0.5, std=0.5, reverse=False):
+ '''
+ Input:
+ im: b x c x h x w, torch tensor
+ Normalize: (im - mean) / std
+ Reverse: im * std + mean
+
+ '''
+ if not isinstance(mean, (list, tuple)):
+ mean = [mean, ] * im.shape[1]
+ mean = torch.tensor(mean, device=im.device).view([1, im.shape[1], 1, 1])
+
+ if not isinstance(std, (list, tuple)):
+ std = [std, ] * im.shape[1]
+ std = torch.tensor(std, device=im.device).view([1, im.shape[1], 1, 1])
+
+ if not reverse:
+ out = (im - mean) / std
+ else:
+ out = im * std + mean
+ return out
+
+# ------------------------Image format--------------------------
+def rgb2ycbcr(im, only_y=True):
+ '''
+ same as matlab rgb2ycbcr
+ Input:
+ im: uint8 [0,255] or float [0,1]
+ only_y: only return Y channel
+ '''
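+    # Uses the ITU-R BT.601 coefficients (scaled to the [0, 255] range), matching MATLAB's rgb2ycbcr.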
+ # transform to float64 data type, range [0, 255]
+ if im.dtype == np.uint8:
+ im_temp = im.astype(np.float64)
+ else:
+ im_temp = (im * 255).astype(np.float64)
+
+ # convert
+ if only_y:
+ rlt = np.dot(im_temp, np.array([65.481, 128.553, 24.966])/ 255.0) + 16.0
+ else:
+ rlt = np.matmul(im_temp, np.array([[65.481, -37.797, 112.0 ],
+ [128.553, -74.203, -93.786],
+ [24.966, 112.0, -18.214]])/255.0) + [16, 128, 128]
+ if im.dtype == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(im.dtype)
+
+def rgb2ycbcrTorch(im, only_y=True):
+ '''
+ same as matlab rgb2ycbcr
+ Input:
+ im: float [0,1], N x 3 x H x W
+ only_y: only return Y channel
+ '''
+ # transform to range [0,255.0]
+    im_temp = im.permute([0,2,3,1]) * 255.0 # N x C x H x W --> N x H x W x C
+ # convert
+ if only_y:
+ rlt = torch.matmul(im_temp, torch.tensor([65.481, 128.553, 24.966],
+ device=im.device, dtype=im.dtype).view([3,1])/ 255.0) + 16.0
+ else:
+ rlt = torch.matmul(im_temp, torch.tensor([[65.481, -37.797, 112.0 ],
+ [128.553, -74.203, -93.786],
+ [24.966, 112.0, -18.214]],
+ device=im.device, dtype=im.dtype)/255.0) + \
+                      torch.tensor([16, 128, 128], device=im.device, dtype=im.dtype).view([-1, 1, 1, 3])
+ rlt /= 255.0
+ rlt.clamp_(0.0, 1.0)
+ return rlt.permute([0, 3, 1, 2])
+
+def bgr2rgb(im): return cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+
+def rgb2bgr(im): return cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
+
+def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
+ """Convert torch Tensors into image numpy arrays.
+
+ After clamping to [min, max], values will be normalized to [0, 1].
+
+ Args:
+ tensor (Tensor or list[Tensor]): Accept shapes:
+ 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
+ 2) 3D Tensor of shape (3/1 x H x W);
+ 3) 2D Tensor of shape (H x W).
+ Tensor channel should be in RGB order.
+ rgb2bgr (bool): Whether to change rgb to bgr.
+ out_type (numpy type): output types. If ``np.uint8``, transform outputs
+ to uint8 type with range [0, 255]; otherwise, float type with
+ range [0, 1]. Default: ``np.uint8``.
+ min_max (tuple[int]): min and max values for clamp.
+
+ Returns:
+ (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
+ shape (H x W). The channel order is BGR.
+ """
+ if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
+ raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
+
+ flag_tensor = torch.is_tensor(tensor)
+ if flag_tensor:
+ tensor = [tensor]
+ result = []
+ for _tensor in tensor:
+ _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
+ _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
+
+ n_dim = _tensor.dim()
+ if n_dim == 4:
+ img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
+ img_np = img_np.transpose(1, 2, 0)
+ if rgb2bgr:
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+ elif n_dim == 3:
+ img_np = _tensor.numpy()
+ img_np = img_np.transpose(1, 2, 0)
+ if img_np.shape[2] == 1: # gray image
+ img_np = np.squeeze(img_np, axis=2)
+ else:
+ if rgb2bgr:
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+ elif n_dim == 2:
+ img_np = _tensor.numpy()
+ else:
+ raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
+ if out_type == np.uint8:
+            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
+ img_np = (img_np * 255.0).round()
+ img_np = img_np.astype(out_type)
+ result.append(img_np)
+ if len(result) == 1 and flag_tensor:
+ result = result[0]
+ return result
+
+def img2tensor(imgs, out_type=torch.float32):
+ """Convert image numpy arrays into torch tensor.
+ Args:
+        imgs (ndarray or list[ndarray]): Accept shapes:
+            1) 3D numpy array of shape (H x W x 3/1);
+            2) 2D numpy array of shape (H x W);
+            3) list of numpy arrays of the above shapes.
+            Channel order should be RGB.
+
+ Returns:
+        (Tensor or list[Tensor]): 4D tensor of shape (1 x C x H x W)
+ """
+
+ def _img2tensor(img):
+ if img.ndim == 2:
+ tensor = torch.from_numpy(img[None, None,]).type(out_type)
+ elif img.ndim == 3:
+ tensor = torch.from_numpy(rearrange(img, 'h w c -> c h w')).type(out_type).unsqueeze(0)
+ else:
+            raise TypeError(f'2D or 3D numpy array expected, got {img.ndim}D array')
+ return tensor
+
+ if not (isinstance(imgs, np.ndarray) or (isinstance(imgs, list) and all(isinstance(t, np.ndarray) for t in imgs))):
+ raise TypeError(f'Numpy array or list of numpy array expected, got {type(imgs)}')
+
+ flag_numpy = isinstance(imgs, np.ndarray)
+ if flag_numpy:
+ imgs = [imgs,]
+ result = []
+ for _img in imgs:
+ result.append(_img2tensor(_img))
+
+ if len(result) == 1 and flag_numpy:
+ result = result[0]
+ return result
+
+# ------------------------Image I/O-----------------------------
+def imread(path, chn='rgb', dtype='float32'):
+ '''
+ Read image.
+ chn: 'rgb', 'bgr' or 'gray'
+ out:
+        im: h x w x c, numpy array
+ '''
+ im = cv2.imread(str(path), cv2.IMREAD_UNCHANGED) # BGR, uint8
+ try:
+ if chn.lower() == 'rgb':
+ if im.ndim == 3:
+ im = bgr2rgb(im)
+ else:
+ im = np.stack((im, im, im), axis=2)
+ elif chn.lower() == 'gray':
+ assert im.ndim == 2
+ except:
+ print(str(path))
+
+ if dtype == 'float32':
+ im = im.astype(np.float32) / 255.
+ elif dtype == 'float64':
+ im = im.astype(np.float64) / 255.
+ elif dtype == 'uint8':
+ pass
+ else:
+        sys.exit('Please input a correct dtype: float32, float64 or uint8!')
+
+ return im
+
+def imwrite(im_in, path, chn='rgb', dtype_in='float32', qf=None):
+ '''
+ Save image.
+ Input:
+        im: h x w x c, numpy array
+        path: the saving path
+        chn: the channel order of the image
+ '''
+ im = im_in.copy()
+ if isinstance(path, str):
+ path = Path(path)
+ if dtype_in != 'uint8':
+ im = img_as_ubyte(im)
+
+ if chn.lower() == 'rgb' and im.ndim == 3:
+ im = rgb2bgr(im)
+
+ if qf is not None and path.suffix.lower() in ['.jpg', '.jpeg']:
+ flag = cv2.imwrite(str(path), im, [int(cv2.IMWRITE_JPEG_QUALITY), int(qf)])
+ else:
+ flag = cv2.imwrite(str(path), im)
+
+ return flag
+
+def jpeg_compress(im, qf, chn_in='rgb'):
+ '''
+ Input:
+ im: h x w x 3 array
+ qf: compress factor, (0, 100]
+ chn_in: 'rgb' or 'bgr'
+ Return:
+ Compressed Image with channel order: chn_in
+ '''
+    # transform to BGR channel and uint8 data type
+ im_bgr = rgb2bgr(im) if chn_in.lower() == 'rgb' else im
+ if im.dtype != np.dtype('uint8'): im_bgr = img_as_ubyte(im_bgr)
+
+ # JPEG compress
+ flag, encimg = cv2.imencode('.jpg', im_bgr, [int(cv2.IMWRITE_JPEG_QUALITY), qf])
+ assert flag
+ im_jpg_bgr = cv2.imdecode(encimg, 1) # uint8, BGR
+
+ # transform back to original channel and the original data type
+ im_out = bgr2rgb(im_jpg_bgr) if chn_in.lower() == 'rgb' else im_jpg_bgr
+ if im.dtype != np.dtype('uint8'): im_out = img_as_float32(im_out).astype(im.dtype)
+ return im_out
+
+# ------------------------Augmentation-----------------------------
+def data_aug_np(image, mode):
+ '''
+ Performs data augmentation of the input image
+ Input:
+ image: a cv2 (OpenCV) image
+ mode: int. Choice of transformation to apply to the image
+ 0 - no transformation
+ 1 - flip up and down
+                2 - rotate counterclockwise 90 degree
+ 3 - rotate 90 degree and flip up and down
+ 4 - rotate 180 degree
+ 5 - rotate 180 degree and flip
+ 6 - rotate 270 degree
+ 7 - rotate 270 degree and flip
+ '''
+ if mode == 0:
+ # original
+ out = image
+ elif mode == 1:
+ # flip up and down
+ out = np.flipud(image)
+ elif mode == 2:
+        # rotate counterclockwise 90 degree
+ out = np.rot90(image)
+ elif mode == 3:
+ # rotate 90 degree and flip up and down
+ out = np.rot90(image)
+ out = np.flipud(out)
+ elif mode == 4:
+ # rotate 180 degree
+ out = np.rot90(image, k=2)
+ elif mode == 5:
+ # rotate 180 degree and flip
+ out = np.rot90(image, k=2)
+ out = np.flipud(out)
+ elif mode == 6:
+ # rotate 270 degree
+ out = np.rot90(image, k=3)
+ elif mode == 7:
+ # rotate 270 degree and flip
+ out = np.rot90(image, k=3)
+ out = np.flipud(out)
+ else:
+ raise Exception('Invalid choice of image transformation')
+
+ return out.copy()
+
+def inverse_data_aug_np(image, mode):
+ '''
+ Performs inverse data augmentation of the input image
+ '''
+ if mode == 0:
+ # original
+ out = image
+ elif mode == 1:
+ out = np.flipud(image)
+ elif mode == 2:
+ out = np.rot90(image, axes=(1,0))
+ elif mode == 3:
+ out = np.flipud(image)
+ out = np.rot90(out, axes=(1,0))
+ elif mode == 4:
+ out = np.rot90(image, k=2, axes=(1,0))
+ elif mode == 5:
+ out = np.flipud(image)
+ out = np.rot90(out, k=2, axes=(1,0))
+ elif mode == 6:
+ out = np.rot90(image, k=3, axes=(1,0))
+ elif mode == 7:
+ # rotate 270 degree and flip
+ out = np.flipud(image)
+ out = np.rot90(out, k=3, axes=(1,0))
+ else:
+ raise Exception('Invalid choice of image transformation')
+
+ return out
+
+class SpatialAug:
+ def __init__(self):
+ pass
+
+ def __call__(self, im, flag=None):
+ if flag is None:
+ flag = random.randint(0, 7)
+
+ out = data_aug_np(im, flag)
+ return out
+
+# ----------------------Visualization----------------------------
+def imshow(x, title=None, cbar=False):
+ import matplotlib.pyplot as plt
+ plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
+ if title:
+ plt.title(title)
+ if cbar:
+ plt.colorbar()
+ plt.show()
+
+# -----------------------Covolution------------------------------
+def imgrad(im, pading_mode='mirror'):
+ '''
+ Calculate image gradient.
+ Input:
+ im: h x w x c numpy array
+ '''
+ from scipy.ndimage import correlate # lazy import
+ wx = np.array([[0, 0, 0],
+ [-1, 1, 0],
+ [0, 0, 0]], dtype=np.float32)
+ wy = np.array([[0, -1, 0],
+ [0, 1, 0],
+ [0, 0, 0]], dtype=np.float32)
+ if im.ndim == 3:
+ gradx = np.stack(
+ [correlate(im[:,:,c], wx, mode=pading_mode) for c in range(im.shape[2])],
+ axis=2
+ )
+ grady = np.stack(
+ [correlate(im[:,:,c], wy, mode=pading_mode) for c in range(im.shape[2])],
+ axis=2
+ )
+ grad = np.concatenate((gradx, grady), axis=2)
+ else:
+ gradx = correlate(im, wx, mode=pading_mode)
+ grady = correlate(im, wy, mode=pading_mode)
+ grad = np.stack((gradx, grady), axis=2)
+
+ return {'gradx': gradx, 'grady': grady, 'grad':grad}
+
+def imgrad_fft(im):
+ '''
+ Calculate image gradient.
+ Input:
+ im: h x w x c numpy array
+ '''
+ wx = np.rot90(np.array([[0, 0, 0],
+ [-1, 1, 0],
+ [0, 0, 0]], dtype=np.float32), k=2)
+ gradx = convfft(im, wx)
+ wy = np.rot90(np.array([[0, -1, 0],
+ [0, 1, 0],
+ [0, 0, 0]], dtype=np.float32), k=2)
+ grady = convfft(im, wy)
+ grad = np.concatenate((gradx, grady), axis=2)
+
+ return {'gradx': gradx, 'grady': grady, 'grad':grad}
+
+def convfft(im, weight):
+ '''
+ Convolution with FFT
+ Input:
+ im: h1 x w1 x c numpy array
+ weight: h2 x w2 numpy array
+ Output:
+ out: h1 x w1 x c numpy array
+ '''
+ axes = (0,1)
+ otf = psf2otf(weight, im.shape[:2])
+ if im.ndim == 3:
+ otf = np.tile(otf[:, :, None], (1,1,im.shape[2]))
+ out = fft.ifft2(fft.fft2(im, axes=axes) * otf, axes=axes).real
+ return out
+
+def psf2otf(psf, shape):
+ """
+ MATLAB psf2otf function.
+ Borrowed from https://github.com/aboucaud/pypher/blob/master/pypher/pypher.py.
+ Input:
+ psf : h x w numpy array
+ shape : list or tuple, output shape of the OTF array
+ Output:
+ otf : OTF array with the desirable shape
+ """
+ if np.all(psf == 0):
+ return np.zeros_like(psf)
+
+ inshape = psf.shape
+ # Pad the PSF to outsize
+ psf = zero_pad(psf, shape, position='corner')
+
+ # Circularly shift OTF so that the 'center' of the PSF is [0,0] element of the array
+ for axis, axis_size in enumerate(inshape):
+ psf = np.roll(psf, -int(axis_size / 2), axis=axis)
+
+ # Compute the OTF
+ otf = fft.fft2(psf)
+
+ # Estimate the rough number of operations involved in the FFT
+ # and discard the PSF imaginary part if within roundoff error
+ # roundoff error = machine epsilon = sys.float_info.epsilon
+ # or np.finfo().eps
+ n_ops = np.sum(psf.size * np.log2(psf.shape))
+ otf = np.real_if_close(otf, tol=n_ops)
+
+ return otf
+
+# ----------------------Patch Cropping----------------------------
+def random_crop(im, pch_size):
+ '''
+    Randomly crop a patch from the given image.
+ '''
+ h, w = im.shape[:2]
+ if h == pch_size and w == pch_size:
+ im_pch = im
+ else:
+        assert h >= pch_size and w >= pch_size
+ ind_h = random.randint(0, h-pch_size)
+ ind_w = random.randint(0, w-pch_size)
+ im_pch = im[ind_h:ind_h+pch_size, ind_w:ind_w+pch_size,]
+
+ return im_pch
+
+class RandomCrop:
+ def __init__(self, pch_size):
+ self.pch_size = pch_size
+
+ def __call__(self, im):
+ return random_crop(im, self.pch_size)
+
+class ImageSpliterNp:
+ def __init__(self, im, pch_size, stride, sf=1):
+ '''
+ Input:
+ im: h x w x c, numpy array, [0, 1], low-resolution image in SR
+ pch_size, stride: patch setting
+ sf: scale factor in image super-resolution
+ '''
+ assert stride <= pch_size
+ self.stride = stride
+ self.pch_size = pch_size
+ self.sf = sf
+
+ if im.ndim == 2:
+ im = im[:, :, None]
+
+ height, width, chn = im.shape
+ self.height_starts_list = self.extract_starts(height)
+ self.width_starts_list = self.extract_starts(width)
+ self.length = self.__len__()
+ self.num_pchs = 0
+
+ self.im_ori = im
+ self.im_res = np.zeros([height*sf, width*sf, chn], dtype=im.dtype)
+ self.pixel_count = np.zeros([height*sf, width*sf, chn], dtype=im.dtype)
+
+ def extract_starts(self, length):
+ starts = list(range(0, length, self.stride))
+ if starts[-1] + self.pch_size > length:
+ starts[-1] = length - self.pch_size
+ return starts
+
+ def __len__(self):
+ return len(self.height_starts_list) * len(self.width_starts_list)
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.num_pchs < self.length:
+ w_start_idx = self.num_pchs // len(self.height_starts_list)
+ w_start = self.width_starts_list[w_start_idx] * self.sf
+ w_end = w_start + self.pch_size * self.sf
+
+ h_start_idx = self.num_pchs % len(self.height_starts_list)
+ h_start = self.height_starts_list[h_start_idx] * self.sf
+ h_end = h_start + self.pch_size * self.sf
+
+ pch = self.im_ori[h_start:h_end, w_start:w_end,]
+ self.w_start, self.w_end = w_start, w_end
+ self.h_start, self.h_end = h_start, h_end
+
+ self.num_pchs += 1
+ else:
+ raise StopIteration(0)
+
+ return pch, (h_start, h_end, w_start, w_end)
+
+ def update(self, pch_res, index_infos):
+ '''
+ Input:
+ pch_res: pch_size x pch_size x 3, [0,1]
+ index_infos: (h_start, h_end, w_start, w_end)
+ '''
+ if index_infos is None:
+ w_start, w_end = self.w_start, self.w_end
+ h_start, h_end = self.h_start, self.h_end
+ else:
+ h_start, h_end, w_start, w_end = index_infos
+
+ self.im_res[h_start:h_end, w_start:w_end] += pch_res
+ self.pixel_count[h_start:h_end, w_start:w_end] += 1
+
+ def gather(self):
+ assert np.all(self.pixel_count != 0)
+ return self.im_res / self.pixel_count
+
+class ImageSpliterTh:
+ def __init__(self, im, pch_size, stride, sf=1):
+ '''
+ Input:
+ im: n x c x h x w, torch tensor, float, low-resolution image in SR
+ pch_size, stride: patch setting
+ sf: scale factor in image super-resolution
+ '''
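+        # Overlapping patch results are accumulated into im_res together with a per-pixel counter,
+        # then averaged in gather(), which blends the seams between neighbouring tiles.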
+ assert stride <= pch_size
+ self.stride = stride
+ self.pch_size = pch_size
+ self.sf = sf
+
+ bs, chn, height, width= im.shape
+ self.height_starts_list = self.extract_starts(height)
+ self.width_starts_list = self.extract_starts(width)
+ self.length = self.__len__()
+ self.num_pchs = 0
+
+ self.im_ori = im
+ self.im_res = torch.zeros([bs, chn, height*sf, width*sf], dtype=im.dtype, device=im.device)
+ self.pixel_count = torch.zeros([bs, chn, height*sf, width*sf], dtype=im.dtype, device=im.device)
+
+ def extract_starts(self, length):
+ if length <= self.pch_size:
+ starts = [0,]
+ else:
+ starts = list(range(0, length, self.stride))
+ for i in range(len(starts)):
+ if starts[i] + self.pch_size > length:
+ starts[i] = length - self.pch_size
+ starts = sorted(set(starts), key=starts.index)
+ return starts
+
+ def __len__(self):
+ return len(self.height_starts_list) * len(self.width_starts_list)
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.num_pchs < self.length:
+ w_start_idx = self.num_pchs // len(self.height_starts_list)
+ w_start = self.width_starts_list[w_start_idx]
+ w_end = w_start + self.pch_size
+
+ h_start_idx = self.num_pchs % len(self.height_starts_list)
+ h_start = self.height_starts_list[h_start_idx]
+ h_end = h_start + self.pch_size
+
+ pch = self.im_ori[:, :, h_start:h_end, w_start:w_end,]
+
+ h_start *= self.sf
+ h_end *= self.sf
+ w_start *= self.sf
+ w_end *= self.sf
+
+ self.w_start, self.w_end = w_start, w_end
+ self.h_start, self.h_end = h_start, h_end
+
+ self.num_pchs += 1
+ else:
+ raise StopIteration()
+
+ return pch, (h_start, h_end, w_start, w_end)
+
+ def update(self, pch_res, index_infos):
+ '''
+ Input:
+ pch_res: n x c x pch_size x pch_size, float
+ index_infos: (h_start, h_end, w_start, w_end)
+ '''
+ if index_infos is None:
+ w_start, w_end = self.w_start, self.w_end
+ h_start, h_end = self.h_start, self.h_end
+ else:
+ h_start, h_end, w_start, w_end = index_infos
+
+ self.im_res[:, :, h_start:h_end, w_start:w_end] += pch_res
+ self.pixel_count[:, :, h_start:h_end, w_start:w_end] += 1
+
+ def gather(self):
+ assert torch.all(self.pixel_count != 0)
+ return self.im_res.div(self.pixel_count)
+
+# ----------------------Value Clamping----------------------------
+class Clamper:
+ def __init__(self, min_max=(-1, 1)):
+ self.min_bound, self.max_bound = min_max[0], min_max[1]
+
+ def __call__(self, im):
+ if isinstance(im, np.ndarray):
+ return np.clip(im, a_min=self.min_bound, a_max=self.max_bound)
+ elif isinstance(im, torch.Tensor):
+ return torch.clamp(im, min=self.min_bound, max=self.max_bound)
+ else:
+ raise TypeError(f'ndarray or Tensor expected, got {type(im)}')
+
+if __name__ == '__main__':
+ im = np.random.randn(64, 64, 3).astype(np.float32)
+
+ grad1 = imgrad(im)['grad']
+ grad2 = imgrad_fft(im)['grad']
+
+    error = np.abs(grad1 - grad2).max()
+    mean_error = np.abs(grad1 - grad2).mean()
+ print('The largest error is {:.2e}'.format(error))
+ print('The mean error is {:.2e}'.format(mean_error))
diff --git a/StableSR/scripts/wavelet_color_fix.py b/StableSR/scripts/wavelet_color_fix.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e8fa852476775161571e849bf5eca1fca1a36b2
--- /dev/null
+++ b/StableSR/scripts/wavelet_color_fix.py
@@ -0,0 +1,119 @@
+'''
+# --------------------------------------------------------------------------------
+# Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
+# --------------------------------------------------------------------------------
+'''
+
+import torch
+from PIL import Image
+from torch import Tensor
+from torch.nn import functional as F
+
+from torchvision.transforms import ToTensor, ToPILImage
+
+def adain_color_fix(target: Image, source: Image):
+ # Convert images to tensors
+ to_tensor = ToTensor()
+ target_tensor = to_tensor(target).unsqueeze(0)
+ source_tensor = to_tensor(source).unsqueeze(0)
+
+ # Apply adaptive instance normalization
+ result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)
+
+ # Convert tensor back to image
+ to_image = ToPILImage()
+ result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
+
+ return result_image
+
+def wavelet_color_fix(target: Image, source: Image):
+ # Convert images to tensors
+ to_tensor = ToTensor()
+ target_tensor = to_tensor(target).unsqueeze(0)
+ source_tensor = to_tensor(source).unsqueeze(0)
+
+ # Apply wavelet reconstruction
+ result_tensor = wavelet_reconstruction(target_tensor, source_tensor)
+
+ # Convert tensor back to image
+ to_image = ToPILImage()
+ result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))
+
+ return result_image
+
+def calc_mean_std(feat: Tensor, eps=1e-5):
+ """Calculate mean and std for adaptive_instance_normalization.
+ Args:
+ feat (Tensor): 4D tensor.
+ eps (float): A small value added to the variance to avoid
+ divide-by-zero. Default: 1e-5.
+ """
+ size = feat.size()
+ assert len(size) == 4, 'The input feature should be 4D tensor.'
+ b, c = size[:2]
+ feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
+ feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
+ feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
+ return feat_mean, feat_std
+
+def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
+ """Adaptive instance normalization.
+    Adjust the reference features so that their color and illumination statistics
+    match those of the degraded features.
+    Args:
+        content_feat (Tensor): The reference feature.
+        style_feat (Tensor): The degraded features.
+ """
+ size = content_feat.size()
+ style_mean, style_std = calc_mean_std(style_feat)
+ content_mean, content_std = calc_mean_std(content_feat)
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+def wavelet_blur(image: Tensor, radius: int):
+ """
+ Apply wavelet blur to the input tensor.
+ """
+ # input shape: (1, 3, H, W)
+ # convolution kernel
+ kernel_vals = [
+ [0.0625, 0.125, 0.0625],
+ [0.125, 0.25, 0.125],
+ [0.0625, 0.125, 0.0625],
+ ]
+ kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
+ # add channel dimensions to the kernel to make it a 4D tensor
+ kernel = kernel[None, None]
+ # repeat the kernel across all input channels
+ kernel = kernel.repeat(3, 1, 1, 1)
+ image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
+ # apply convolution
+ output = F.conv2d(image, kernel, groups=3, dilation=radius)
+ return output
+
+def wavelet_decomposition(image: Tensor, levels=5):
+ """
+ Apply wavelet decomposition to the input tensor.
+ This function only returns the low frequency & the high frequency.
+ """
+ high_freq = torch.zeros_like(image)
+ for i in range(levels):
+ radius = 2 ** i
+ low_freq = wavelet_blur(image, radius)
+ high_freq += (image - low_freq)
+ image = low_freq
+
+ return high_freq, low_freq
+
+def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
+ """
+    Recombine the content's high-frequency detail with the style's low-frequency component, so the content takes on the style's colors.
+ """
+ # calculate the wavelet decomposition of the content feature
+ content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
+ del content_low_freq
+ # calculate the wavelet decomposition of the style feature
+ style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
+ del style_high_freq
+ # reconstruct the content feature with the style's high frequency
+ return content_high_freq + style_low_freq
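+
+# Minimal usage sketch (illustrative only; assumes 'restored.png' and 'lq.png' exist, with the LQ
+# image resized to match the restored output):
+#   from PIL import Image
+#   restored = Image.open('restored.png').convert('RGB')
+#   lq = Image.open('lq.png').convert('RGB').resize(restored.size)
+#   fixed = wavelet_color_fix(restored, lq)
+#   fixed.save('restored_colorfix.png')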
diff --git a/StableSR/setup.py b/StableSR/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..65926c7efa09951216f9b10d8776a4a3aeebc363
--- /dev/null
+++ b/StableSR/setup.py
@@ -0,0 +1,13 @@
+from setuptools import setup, find_packages
+
+setup(
+ name='StableSR',
+ version='0.0.1',
+ description='',
+ packages=find_packages(),
+ install_requires=[
+ 'torch',
+ 'numpy',
+ 'tqdm',
+ ],
+)
diff --git a/StableSR/taming/data/ade20k.py b/StableSR/taming/data/ade20k.py
new file mode 100644
index 0000000000000000000000000000000000000000..366dae97207dbb8356598d636e14ad084d45bc76
--- /dev/null
+++ b/StableSR/taming/data/ade20k.py
@@ -0,0 +1,124 @@
+import os
+import numpy as np
+import cv2
+import albumentations
+from PIL import Image
+from torch.utils.data import Dataset
+
+from taming.data.sflckr import SegmentationBase # for examples included in repo
+
+
+class Examples(SegmentationBase):
+ def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
+ super().__init__(data_csv="data/ade20k_examples.txt",
+ data_root="data/ade20k_images",
+ segmentation_root="data/ade20k_segmentations",
+ size=size, random_crop=random_crop,
+ interpolation=interpolation,
+ n_labels=151, shift_segmentation=False)
+
+
+# With semantic map and scene label
+class ADE20kBase(Dataset):
+ def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None):
+ self.split = self.get_split()
+ self.n_labels = 151 # unknown + 150
+ self.data_csv = {"train": "data/ade20k_train.txt",
+ "validation": "data/ade20k_test.txt"}[self.split]
+ self.data_root = "data/ade20k_root"
+ with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f:
+ self.scene_categories = f.read().splitlines()
+ self.scene_categories = dict(line.split() for line in self.scene_categories)
+ with open(self.data_csv, "r") as f:
+ self.image_paths = f.read().splitlines()
+ self._length = len(self.image_paths)
+ self.labels = {
+ "relative_file_path_": [l for l in self.image_paths],
+ "file_path_": [os.path.join(self.data_root, "images", l)
+ for l in self.image_paths],
+ "relative_segmentation_path_": [l.replace(".jpg", ".png")
+ for l in self.image_paths],
+ "segmentation_path_": [os.path.join(self.data_root, "annotations",
+ l.replace(".jpg", ".png"))
+ for l in self.image_paths],
+ "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")]
+ for l in self.image_paths],
+ }
+
+ size = None if size is not None and size<=0 else size
+ self.size = size
+ if crop_size is None:
+ self.crop_size = size if size is not None else None
+ else:
+ self.crop_size = crop_size
+ if self.size is not None:
+ self.interpolation = interpolation
+ self.interpolation = {
+ "nearest": cv2.INTER_NEAREST,
+ "bilinear": cv2.INTER_LINEAR,
+ "bicubic": cv2.INTER_CUBIC,
+ "area": cv2.INTER_AREA,
+ "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
+ self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
+ interpolation=self.interpolation)
+ self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
+ interpolation=cv2.INTER_NEAREST)
+
+ if crop_size is not None:
+ self.center_crop = not random_crop
+ if self.center_crop:
+ self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
+ else:
+ self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
+ self.preprocessor = self.cropper
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = dict((k, self.labels[k][i]) for k in self.labels)
+ image = Image.open(example["file_path_"])
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+ if self.size is not None:
+ image = self.image_rescaler(image=image)["image"]
+ segmentation = Image.open(example["segmentation_path_"])
+ segmentation = np.array(segmentation).astype(np.uint8)
+ if self.size is not None:
+ segmentation = self.segmentation_rescaler(image=segmentation)["image"]
+ if self.size is not None:
+ processed = self.preprocessor(image=image, mask=segmentation)
+ else:
+ processed = {"image": image, "mask": segmentation}
+ example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
+ segmentation = processed["mask"]
+ onehot = np.eye(self.n_labels)[segmentation]
+ example["segmentation"] = onehot
+ return example
+
+
+class ADE20kTrain(ADE20kBase):
+ # default to random_crop=True
+ def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None):
+ super().__init__(config=config, size=size, random_crop=random_crop,
+ interpolation=interpolation, crop_size=crop_size)
+
+ def get_split(self):
+ return "train"
+
+
+class ADE20kValidation(ADE20kBase):
+ def get_split(self):
+ return "validation"
+
+
+if __name__ == "__main__":
+ dset = ADE20kValidation()
+ ex = dset[0]
+ for k in ["image", "scene_category", "segmentation"]:
+ print(type(ex[k]))
+ try:
+ print(ex[k].shape)
+ except:
+ print(ex[k])
diff --git a/StableSR/taming/data/annotated_objects_coco.py b/StableSR/taming/data/annotated_objects_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..af000ecd943d7b8a85d7eb70195c9ecd10ab5edc
--- /dev/null
+++ b/StableSR/taming/data/annotated_objects_coco.py
@@ -0,0 +1,139 @@
+import json
+from itertools import chain
+from pathlib import Path
+from typing import Iterable, Dict, List, Callable, Any
+from collections import defaultdict
+
+from tqdm import tqdm
+
+from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
+from taming.data.helper_types import Annotation, ImageDescription, Category
+
+COCO_PATH_STRUCTURE = {
+ 'train': {
+ 'top_level': '',
+ 'instances_annotations': 'annotations/instances_train2017.json',
+ 'stuff_annotations': 'annotations/stuff_train2017.json',
+ 'files': 'train2017'
+ },
+ 'validation': {
+ 'top_level': '',
+ 'instances_annotations': 'annotations/instances_val2017.json',
+ 'stuff_annotations': 'annotations/stuff_val2017.json',
+ 'files': 'val2017'
+ }
+}
+
+
+def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]:
+ return {
+ str(img['id']): ImageDescription(
+ id=img['id'],
+ license=img.get('license'),
+ file_name=img['file_name'],
+ coco_url=img['coco_url'],
+ original_size=(img['width'], img['height']),
+ date_captured=img.get('date_captured'),
+ flickr_url=img.get('flickr_url')
+ )
+ for img in description_json
+ }
+
+
+def load_categories(category_json: Iterable) -> Dict[str, Category]:
+ return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name'])
+ for cat in category_json if cat['name'] != 'other'}
+
+
+def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription],
+ category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]:
+ annotations = defaultdict(list)
+ total = sum(len(a) for a in annotations_json)
+ for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total):
+ image_id = str(ann['image_id'])
+ if image_id not in image_descriptions:
+ raise ValueError(f'image_id [{image_id}] has no image description.')
+ category_id = ann['category_id']
+ try:
+ category_no = category_no_for_id(str(category_id))
+ except KeyError:
+ continue
+
+ width, height = image_descriptions[image_id].original_size
+ bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height)
+
+ annotations[image_id].append(
+ Annotation(
+ id=ann['id'],
+ area=bbox[2]*bbox[3], # use bbox area
+ is_group_of=ann['iscrowd'],
+ image_id=ann['image_id'],
+ bbox=bbox,
+ category_id=str(category_id),
+ category_no=category_no
+ )
+ )
+ return dict(annotations)
+
+
+class AnnotatedObjectsCoco(AnnotatedObjectsDataset):
+ def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs):
+ """
+ @param data_path: is the path to the following folder structure:
+ coco/
+ ├── annotations
+ │ ├── instances_train2017.json
+ │ ├── instances_val2017.json
+ │ ├── stuff_train2017.json
+ │ └── stuff_val2017.json
+ ├── train2017
+ │ ├── 000000000009.jpg
+ │ ├── 000000000025.jpg
+ │ └── ...
+ ├── val2017
+ │ ├── 000000000139.jpg
+ │ ├── 000000000285.jpg
+ │ └── ...
+        @param split: one of 'train' or 'validation'
+        @param target_image_size: desired image size (returns square images)
+ """
+ super().__init__(**kwargs)
+ self.use_things = use_things
+ self.use_stuff = use_stuff
+
+ with open(self.paths['instances_annotations']) as f:
+ inst_data_json = json.load(f)
+ with open(self.paths['stuff_annotations']) as f:
+ stuff_data_json = json.load(f)
+
+ category_jsons = []
+ annotation_jsons = []
+ if self.use_things:
+ category_jsons.append(inst_data_json['categories'])
+ annotation_jsons.append(inst_data_json['annotations'])
+ if self.use_stuff:
+ category_jsons.append(stuff_data_json['categories'])
+ annotation_jsons.append(stuff_data_json['annotations'])
+
+ self.categories = load_categories(chain(*category_jsons))
+ self.filter_categories()
+ self.setup_category_id_and_number()
+
+ self.image_descriptions = load_image_descriptions(inst_data_json['images'])
+ annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split)
+ self.annotations = self.filter_object_number(annotations, self.min_object_area,
+ self.min_objects_per_image, self.max_objects_per_image)
+ self.image_ids = list(self.annotations.keys())
+ self.clean_up_annotations_and_image_descriptions()
+
+ def get_path_structure(self) -> Dict[str, str]:
+ if self.split not in COCO_PATH_STRUCTURE:
+            raise ValueError(f'Split [{self.split}] does not exist for COCO data.')
+ return COCO_PATH_STRUCTURE[self.split]
+
+ def get_image_path(self, image_id: str) -> Path:
+ return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name)
+
+ def get_image_description(self, image_id: str) -> Dict[str, Any]:
+ # noinspection PyProtectedMember
+ return self.image_descriptions[image_id]._asdict()
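+
+
+if __name__ == "__main__":
+    # Editor's sketch, not part of the upstream taming code: illustrates how the
+    # constructor arguments are forwarded to AnnotatedObjectsDataset. The data_path
+    # and the numeric values below are assumptions chosen only to show the interface;
+    # data/coco must follow the folder structure documented in the class docstring.
+    dset = AnnotatedObjectsCoco(
+        use_things=True, use_stuff=False,
+        data_path="data/coco", split="validation", keys=["image", "objects_bbox"],
+        target_image_size=256, min_object_area=0.00001, min_objects_per_image=2,
+        max_objects_per_image=30, crop_method="random-1d", random_flip=True,
+        no_tokens=1024, use_group_parameter=True, encode_crop=True)
+    print(len(dset))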
diff --git a/StableSR/taming/data/annotated_objects_dataset.py b/StableSR/taming/data/annotated_objects_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..53cc346a1c76289a4964d7dc8a29582172f33dc0
--- /dev/null
+++ b/StableSR/taming/data/annotated_objects_dataset.py
@@ -0,0 +1,218 @@
+from pathlib import Path
+from typing import Optional, List, Callable, Dict, Any, Union
+import warnings
+
+import PIL.Image as pil_image
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder
+from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
+from taming.data.conditional_builder.utils import load_object_from_string
+from taming.data.helper_types import BoundingBox, CropMethodType, Image, Annotation, SplitType
+from taming.data.image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \
+ Random2dCropReturnCoordinates, RandomHorizontalFlipReturn, convert_pil_to_tensor
+
+
+class AnnotatedObjectsDataset(Dataset):
+ def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str], target_image_size: int,
+ min_object_area: float, min_objects_per_image: int, max_objects_per_image: int,
+ crop_method: CropMethodType, random_flip: bool, no_tokens: int, use_group_parameter: bool,
+ encode_crop: bool, category_allow_list_target: str = "", category_mapping_target: str = "",
+ no_object_classes: Optional[int] = None):
+ self.data_path = data_path
+ self.split = split
+ self.keys = keys
+ self.target_image_size = target_image_size
+ self.min_object_area = min_object_area
+ self.min_objects_per_image = min_objects_per_image
+ self.max_objects_per_image = max_objects_per_image
+ self.crop_method = crop_method
+ self.random_flip = random_flip
+ self.no_tokens = no_tokens
+ self.use_group_parameter = use_group_parameter
+ self.encode_crop = encode_crop
+
+ self.annotations = None
+ self.image_descriptions = None
+ self.categories = None
+ self.category_ids = None
+ self.category_number = None
+ self.image_ids = None
+ self.transform_functions: List[Callable] = self.setup_transform(target_image_size, crop_method, random_flip)
+ self.paths = self.build_paths(self.data_path)
+ self._conditional_builders = None
+ self.category_allow_list = None
+ if category_allow_list_target:
+ allow_list = load_object_from_string(category_allow_list_target)
+ self.category_allow_list = {name for name, _ in allow_list}
+ self.category_mapping = {}
+ if category_mapping_target:
+ self.category_mapping = load_object_from_string(category_mapping_target)
+ self.no_object_classes = no_object_classes
+
+ def build_paths(self, top_level: Union[str, Path]) -> Dict[str, Path]:
+ top_level = Path(top_level)
+ sub_paths = {name: top_level.joinpath(sub_path) for name, sub_path in self.get_path_structure().items()}
+ for path in sub_paths.values():
+ if not path.exists():
+ raise FileNotFoundError(f'{type(self).__name__} data structure error: [{path}] does not exist.')
+ return sub_paths
+
+ @staticmethod
+ def load_image_from_disk(path: Path) -> Image:
+ return pil_image.open(path).convert('RGB')
+
+ @staticmethod
+ def setup_transform(target_image_size: int, crop_method: CropMethodType, random_flip: bool):
+ transform_functions = []
+ if crop_method == 'none':
+ transform_functions.append(transforms.Resize((target_image_size, target_image_size)))
+ elif crop_method == 'center':
+ transform_functions.extend([
+ transforms.Resize(target_image_size),
+ CenterCropReturnCoordinates(target_image_size)
+ ])
+ elif crop_method == 'random-1d':
+ transform_functions.extend([
+ transforms.Resize(target_image_size),
+ RandomCrop1dReturnCoordinates(target_image_size)
+ ])
+ elif crop_method == 'random-2d':
+ transform_functions.extend([
+ Random2dCropReturnCoordinates(target_image_size),
+ transforms.Resize(target_image_size)
+ ])
+ elif crop_method is None:
+ return None
+ else:
+ raise ValueError(f'Received invalid crop method [{crop_method}].')
+ if random_flip:
+ transform_functions.append(RandomHorizontalFlipReturn())
+ transform_functions.append(transforms.Lambda(lambda x: x / 127.5 - 1.))
+ return transform_functions
+
+ def image_transform(self, x: Tensor) -> (Optional[BoundingBox], Optional[bool], Tensor):
+ crop_bbox = None
+ flipped = None
+ for t in self.transform_functions:
+ if isinstance(t, (RandomCrop1dReturnCoordinates, CenterCropReturnCoordinates, Random2dCropReturnCoordinates)):
+ crop_bbox, x = t(x)
+ elif isinstance(t, RandomHorizontalFlipReturn):
+ flipped, x = t(x)
+ else:
+ x = t(x)
+ return crop_bbox, flipped, x
+
+ @property
+ def no_classes(self) -> int:
+ return self.no_object_classes if self.no_object_classes else len(self.categories)
+
+ @property
+ def conditional_builders(self) -> ObjectsCenterPointsConditionalBuilder:
+ # cannot set this up in init because no_classes is only known after loading data in init of superclass
+ if self._conditional_builders is None:
+ self._conditional_builders = {
+ 'objects_center_points': ObjectsCenterPointsConditionalBuilder(
+ self.no_classes,
+ self.max_objects_per_image,
+ self.no_tokens,
+ self.encode_crop,
+ self.use_group_parameter,
+ getattr(self, 'use_additional_parameters', False)
+ ),
+ 'objects_bbox': ObjectsBoundingBoxConditionalBuilder(
+ self.no_classes,
+ self.max_objects_per_image,
+ self.no_tokens,
+ self.encode_crop,
+ self.use_group_parameter,
+ getattr(self, 'use_additional_parameters', False)
+ )
+ }
+ return self._conditional_builders
+
+ def filter_categories(self) -> None:
+ if self.category_allow_list:
+ self.categories = {id_: cat for id_, cat in self.categories.items() if cat.name in self.category_allow_list}
+ if self.category_mapping:
+ self.categories = {id_: cat for id_, cat in self.categories.items() if cat.id not in self.category_mapping}
+
+ def setup_category_id_and_number(self) -> None:
+ self.category_ids = list(self.categories.keys())
+ self.category_ids.sort()
+ if '/m/01s55n' in self.category_ids:
+ self.category_ids.remove('/m/01s55n')
+ self.category_ids.append('/m/01s55n')
+ self.category_number = {category_id: i for i, category_id in enumerate(self.category_ids)}
+        if self.category_allow_list is not None and not self.category_mapping \
+                and len(self.category_ids) != len(self.category_allow_list):
+ warnings.warn('Unexpected number of categories: Mismatch with category_allow_list. '
+ 'Make sure all names in category_allow_list exist.')
+
+ def clean_up_annotations_and_image_descriptions(self) -> None:
+ image_id_set = set(self.image_ids)
+ self.annotations = {k: v for k, v in self.annotations.items() if k in image_id_set}
+ self.image_descriptions = {k: v for k, v in self.image_descriptions.items() if k in image_id_set}
+
+ @staticmethod
+ def filter_object_number(all_annotations: Dict[str, List[Annotation]], min_object_area: float,
+ min_objects_per_image: int, max_objects_per_image: int) -> Dict[str, List[Annotation]]:
+ filtered = {}
+ for image_id, annotations in all_annotations.items():
+ annotations_with_min_area = [a for a in annotations if a.area > min_object_area]
+ if min_objects_per_image <= len(annotations_with_min_area) <= max_objects_per_image:
+ filtered[image_id] = annotations_with_min_area
+ return filtered
+
+ def __len__(self):
+ return len(self.image_ids)
+
+ def __getitem__(self, n: int) -> Dict[str, Any]:
+ image_id = self.get_image_id(n)
+ sample = self.get_image_description(image_id)
+ sample['annotations'] = self.get_annotation(image_id)
+
+ if 'image' in self.keys:
+ sample['image_path'] = str(self.get_image_path(image_id))
+ sample['image'] = self.load_image_from_disk(sample['image_path'])
+ sample['image'] = convert_pil_to_tensor(sample['image'])
+ sample['crop_bbox'], sample['flipped'], sample['image'] = self.image_transform(sample['image'])
+ sample['image'] = sample['image'].permute(1, 2, 0)
+
+ for conditional, builder in self.conditional_builders.items():
+ if conditional in self.keys:
+ sample[conditional] = builder.build(sample['annotations'], sample['crop_bbox'], sample['flipped'])
+
+ if self.keys:
+ # only return specified keys
+ sample = {key: sample[key] for key in self.keys}
+ return sample
+
+ def get_image_id(self, no: int) -> str:
+ return self.image_ids[no]
+
+    def get_annotation(self, image_id: str) -> List[Annotation]:
+ return self.annotations[image_id]
+
+ def get_textual_label_for_category_id(self, category_id: str) -> str:
+ return self.categories[category_id].name
+
+ def get_textual_label_for_category_no(self, category_no: int) -> str:
+ return self.categories[self.get_category_id(category_no)].name
+
+ def get_category_number(self, category_id: str) -> int:
+ return self.category_number[category_id]
+
+ def get_category_id(self, category_no: int) -> str:
+ return self.category_ids[category_no]
+
+ def get_image_description(self, image_id: str) -> Dict[str, Any]:
+ raise NotImplementedError()
+
+ def get_path_structure(self):
+ raise NotImplementedError
+
+ def get_image_path(self, image_id: str) -> Path:
+ raise NotImplementedError
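+
+
+# Editor's note, not part of the upstream taming code: __getitem__ assembles a sample
+# dict containing the image description, its annotations, the augmented image in HWC
+# layout scaled to [-1, 1], and one token sequence per requested conditional builder
+# ('objects_center_points' or 'objects_bbox'); 'crop_bbox' and 'flipped' record the
+# augmentation internally so the builder can rescale the annotations, and only the
+# entries listed in `keys` are returned to the caller.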
diff --git a/StableSR/taming/data/annotated_objects_open_images.py b/StableSR/taming/data/annotated_objects_open_images.py
new file mode 100644
index 0000000000000000000000000000000000000000..aede6803d2cef7a74ca784e7907d35fba6c71239
--- /dev/null
+++ b/StableSR/taming/data/annotated_objects_open_images.py
@@ -0,0 +1,137 @@
+from collections import defaultdict
+from csv import DictReader, reader as TupleReader
+from pathlib import Path
+from typing import Dict, List, Any
+import warnings
+
+from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
+from taming.data.helper_types import Annotation, Category
+from tqdm import tqdm
+
+OPEN_IMAGES_STRUCTURE = {
+ 'train': {
+ 'top_level': '',
+ 'class_descriptions': 'class-descriptions-boxable.csv',
+ 'annotations': 'oidv6-train-annotations-bbox.csv',
+ 'file_list': 'train-images-boxable.csv',
+ 'files': 'train'
+ },
+ 'validation': {
+ 'top_level': '',
+ 'class_descriptions': 'class-descriptions-boxable.csv',
+ 'annotations': 'validation-annotations-bbox.csv',
+ 'file_list': 'validation-images.csv',
+ 'files': 'validation'
+ },
+ 'test': {
+ 'top_level': '',
+ 'class_descriptions': 'class-descriptions-boxable.csv',
+ 'annotations': 'test-annotations-bbox.csv',
+ 'file_list': 'test-images.csv',
+ 'files': 'test'
+ }
+}
+
+
+def load_annotations(descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str],
+ category_no_for_id: Dict[str, int]) -> Dict[str, List[Annotation]]:
+ annotations: Dict[str, List[Annotation]] = defaultdict(list)
+ with open(descriptor_path) as file:
+ reader = DictReader(file)
+ for i, row in tqdm(enumerate(reader), total=14620000, desc='Loading OpenImages annotations'):
+ width = float(row['XMax']) - float(row['XMin'])
+ height = float(row['YMax']) - float(row['YMin'])
+ area = width * height
+ category_id = row['LabelName']
+ if category_id in category_mapping:
+ category_id = category_mapping[category_id]
+ if area >= min_object_area and category_id in category_no_for_id:
+ annotations[row['ImageID']].append(
+ Annotation(
+ id=i,
+ image_id=row['ImageID'],
+ source=row['Source'],
+ category_id=category_id,
+ category_no=category_no_for_id[category_id],
+ confidence=float(row['Confidence']),
+ bbox=(float(row['XMin']), float(row['YMin']), width, height),
+ area=area,
+ is_occluded=bool(int(row['IsOccluded'])),
+ is_truncated=bool(int(row['IsTruncated'])),
+ is_group_of=bool(int(row['IsGroupOf'])),
+ is_depiction=bool(int(row['IsDepiction'])),
+ is_inside=bool(int(row['IsInside']))
+ )
+ )
+ if 'train' in str(descriptor_path) and i < 14000000:
+ warnings.warn(f'Running with subset of Open Images. Train dataset has length [{len(annotations)}].')
+ return dict(annotations)
+
+
+def load_image_ids(csv_path: Path) -> List[str]:
+ with open(csv_path) as file:
+ reader = DictReader(file)
+ return [row['image_name'] for row in reader]
+
+
+def load_categories(csv_path: Path) -> Dict[str, Category]:
+ with open(csv_path) as file:
+ reader = TupleReader(file)
+ return {row[0]: Category(id=row[0], name=row[1], super_category=None) for row in reader}
+
+
+class AnnotatedObjectsOpenImages(AnnotatedObjectsDataset):
+ def __init__(self, use_additional_parameters: bool, **kwargs):
+ """
+ @param data_path: is the path to the following folder structure:
+ open_images/
+ ├── class-descriptions-boxable.csv
+ ├── oidv6-train-annotations-bbox.csv
+ ├── test
+ │ ├── 000026e7ee790996.jpg
+ │ ├── 000062a39995e348.jpg
+ │ └── ...
+ ├── test-annotations-bbox.csv
+ ├── test-images.csv
+ ├── train
+ │ ├── 000002b66c9c498e.jpg
+ │ ├── 000002b97e5471a0.jpg
+ │ └── ...
+ ├── train-images-boxable.csv
+ ├── validation
+ │ ├── 0001eeaf4aed83f9.jpg
+ │ ├── 0004886b7d043cfd.jpg
+ │ └── ...
+ ├── validation-annotations-bbox.csv
+ └── validation-images.csv
+        @param split: one of 'train', 'validation' or 'test'
+        @param target_image_size: desired image size (returns square images)
+ """
+
+ super().__init__(**kwargs)
+ self.use_additional_parameters = use_additional_parameters
+
+ self.categories = load_categories(self.paths['class_descriptions'])
+ self.filter_categories()
+ self.setup_category_id_and_number()
+
+ self.image_descriptions = {}
+ annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_mapping,
+ self.category_number)
+ self.annotations = self.filter_object_number(annotations, self.min_object_area, self.min_objects_per_image,
+ self.max_objects_per_image)
+ self.image_ids = list(self.annotations.keys())
+ self.clean_up_annotations_and_image_descriptions()
+
+ def get_path_structure(self) -> Dict[str, str]:
+ if self.split not in OPEN_IMAGES_STRUCTURE:
+            raise ValueError(f'Split [{self.split}] does not exist for Open Images data.')
+ return OPEN_IMAGES_STRUCTURE[self.split]
+
+ def get_image_path(self, image_id: str) -> Path:
+ return self.paths['files'].joinpath(f'{image_id:0>16}.jpg')
+
+ def get_image_description(self, image_id: str) -> Dict[str, Any]:
+ image_path = self.get_image_path(image_id)
+ return {'file_path': str(image_path), 'file_name': image_path.name}
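+
+
+if __name__ == "__main__":
+    # Editor's sketch, not part of the upstream taming code: same interface as the
+    # COCO variant plus `use_additional_parameters`, which exposes the occluded /
+    # depiction / inside flags to the conditional builders. The path and numeric
+    # values below are assumptions for illustration only.
+    dset = AnnotatedObjectsOpenImages(
+        use_additional_parameters=True,
+        data_path="data/open_images", split="validation", keys=["image", "objects_bbox"],
+        target_image_size=256, min_object_area=0.0001, min_objects_per_image=2,
+        max_objects_per_image=30, crop_method="random-1d", random_flip=True,
+        no_tokens=1024, use_group_parameter=True, encode_crop=True)
+    print(len(dset))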
diff --git a/StableSR/taming/data/base.py b/StableSR/taming/data/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..e21667df4ce4baa6bb6aad9f8679bd756e2ffdb7
--- /dev/null
+++ b/StableSR/taming/data/base.py
@@ -0,0 +1,70 @@
+import bisect
+import numpy as np
+import albumentations
+from PIL import Image
+from torch.utils.data import Dataset, ConcatDataset
+
+
+class ConcatDatasetWithIndex(ConcatDataset):
+ """Modified from original pytorch code to return dataset idx"""
+ def __getitem__(self, idx):
+ if idx < 0:
+ if -idx > len(self):
+ raise ValueError("absolute value of index should not exceed dataset length")
+ idx = len(self) + idx
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+ if dataset_idx == 0:
+ sample_idx = idx
+ else:
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+ return self.datasets[dataset_idx][sample_idx], dataset_idx
+
+
+class ImagePaths(Dataset):
+ def __init__(self, paths, size=None, random_crop=False, labels=None):
+ self.size = size
+ self.random_crop = random_crop
+
+ self.labels = dict() if labels is None else labels
+ self.labels["file_path_"] = paths
+ self._length = len(paths)
+
+ if self.size is not None and self.size > 0:
+ self.rescaler = albumentations.SmallestMaxSize(max_size = self.size)
+ if not self.random_crop:
+ self.cropper = albumentations.CenterCrop(height=self.size,width=self.size)
+ else:
+ self.cropper = albumentations.RandomCrop(height=self.size,width=self.size)
+ self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
+ else:
+ self.preprocessor = lambda **kwargs: kwargs
+
+ def __len__(self):
+ return self._length
+
+ def preprocess_image(self, image_path):
+ image = Image.open(image_path)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+ image = self.preprocessor(image=image)["image"]
+ image = (image/127.5 - 1.0).astype(np.float32)
+ return image
+
+ def __getitem__(self, i):
+ example = dict()
+ example["image"] = self.preprocess_image(self.labels["file_path_"][i])
+ for k in self.labels:
+ example[k] = self.labels[k][i]
+ return example
+
+
+class NumpyPaths(ImagePaths):
+ def preprocess_image(self, image_path):
+ image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024
+ image = np.transpose(image, (1,2,0))
+ image = Image.fromarray(image, mode="RGB")
+ image = np.array(image).astype(np.uint8)
+ image = self.preprocessor(image=image)["image"]
+ image = (image/127.5 - 1.0).astype(np.float32)
+ return image
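+
+
+if __name__ == "__main__":
+    # Editor's sketch, not part of the upstream taming code: ImagePaths rescales the
+    # shortest side to `size`, takes a square centre crop (or random crop) and returns
+    # HWC float32 images in [-1, 1]. The file paths below are assumptions.
+    dset = ImagePaths(paths=["data/example_0.png", "data/example_1.png"], size=256)
+    print(len(dset))  # 2
+    # dset[0]["image"].shape == (256, 256, 3) once the files exist on disk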
diff --git a/StableSR/taming/data/coco.py b/StableSR/taming/data/coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b2f7838448cb63dcf96daffe9470d58566d975a
--- /dev/null
+++ b/StableSR/taming/data/coco.py
@@ -0,0 +1,176 @@
+import os
+import json
+import albumentations
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+from torch.utils.data import Dataset
+
+from taming.data.sflckr import SegmentationBase # for examples included in repo
+
+
+class Examples(SegmentationBase):
+ def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
+ super().__init__(data_csv="data/coco_examples.txt",
+ data_root="data/coco_images",
+ segmentation_root="data/coco_segmentations",
+ size=size, random_crop=random_crop,
+ interpolation=interpolation,
+ n_labels=183, shift_segmentation=True)
+
+
+class CocoBase(Dataset):
+ """needed for (image, caption, segmentation) pairs"""
+ def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False,
+ crop_size=None, force_no_crop=False, given_files=None):
+ self.split = self.get_split()
+ self.size = size
+ if crop_size is None:
+ self.crop_size = size
+ else:
+ self.crop_size = crop_size
+
+ self.onehot = onehot_segmentation # return segmentation as rgb or one hot
+ self.stuffthing = use_stuffthing # include thing in segmentation
+ if self.onehot and not self.stuffthing:
+            raise NotImplementedError("One hot mode is only supported for the "
+                                      "stuffthings version because labels are stored "
+                                      "a bit differently.")
+
+ data_json = datajson
+ with open(data_json) as json_file:
+ self.json_data = json.load(json_file)
+ self.img_id_to_captions = dict()
+ self.img_id_to_filepath = dict()
+ self.img_id_to_segmentation_filepath = dict()
+
+ assert data_json.split("/")[-1] in ["captions_train2017.json",
+ "captions_val2017.json"]
+ if self.stuffthing:
+ self.segmentation_prefix = (
+ "data/cocostuffthings/val2017" if
+ data_json.endswith("captions_val2017.json") else
+ "data/cocostuffthings/train2017")
+ else:
+ self.segmentation_prefix = (
+ "data/coco/annotations/stuff_val2017_pixelmaps" if
+ data_json.endswith("captions_val2017.json") else
+ "data/coco/annotations/stuff_train2017_pixelmaps")
+
+ imagedirs = self.json_data["images"]
+ self.labels = {"image_ids": list()}
+ for imgdir in tqdm(imagedirs, desc="ImgToPath"):
+ self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"])
+ self.img_id_to_captions[imgdir["id"]] = list()
+ pngfilename = imgdir["file_name"].replace("jpg", "png")
+ self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join(
+ self.segmentation_prefix, pngfilename)
+ if given_files is not None:
+ if pngfilename in given_files:
+ self.labels["image_ids"].append(imgdir["id"])
+ else:
+ self.labels["image_ids"].append(imgdir["id"])
+
+ capdirs = self.json_data["annotations"]
+ for capdir in tqdm(capdirs, desc="ImgToCaptions"):
+ # there are in average 5 captions per image
+ self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]]))
+
+ self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
+ if self.split=="validation":
+ self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
+ else:
+ self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
+ self.preprocessor = albumentations.Compose(
+ [self.rescaler, self.cropper],
+ additional_targets={"segmentation": "image"})
+ if force_no_crop:
+ self.rescaler = albumentations.Resize(height=self.size, width=self.size)
+ self.preprocessor = albumentations.Compose(
+ [self.rescaler],
+ additional_targets={"segmentation": "image"})
+
+ def __len__(self):
+ return len(self.labels["image_ids"])
+
+ def preprocess_image(self, image_path, segmentation_path):
+ image = Image.open(image_path)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+
+ segmentation = Image.open(segmentation_path)
+ if not self.onehot and not segmentation.mode == "RGB":
+ segmentation = segmentation.convert("RGB")
+ segmentation = np.array(segmentation).astype(np.uint8)
+ if self.onehot:
+ assert self.stuffthing
+ # stored in caffe format: unlabeled==255. stuff and thing from
+ # 0-181. to be compatible with the labels in
+ # https://github.com/nightrome/cocostuff/blob/master/labels.txt
+ # we shift stuffthing one to the right and put unlabeled in zero
+ # as long as segmentation is uint8 shifting to right handles the
+ # latter too
+ assert segmentation.dtype == np.uint8
+ segmentation = segmentation + 1
+
+ processed = self.preprocessor(image=image, segmentation=segmentation)
+ image, segmentation = processed["image"], processed["segmentation"]
+ image = (image / 127.5 - 1.0).astype(np.float32)
+
+ if self.onehot:
+ assert segmentation.dtype == np.uint8
+ # make it one hot
+ n_labels = 183
+ flatseg = np.ravel(segmentation)
+            onehot = np.zeros((flatseg.size, n_labels), dtype=bool)
+ onehot[np.arange(flatseg.size), flatseg] = True
+ onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
+ segmentation = onehot
+ else:
+ segmentation = (segmentation / 127.5 - 1.0).astype(np.float32)
+ return image, segmentation
+
+ def __getitem__(self, i):
+ img_path = self.img_id_to_filepath[self.labels["image_ids"][i]]
+ seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]]
+ image, segmentation = self.preprocess_image(img_path, seg_path)
+ captions = self.img_id_to_captions[self.labels["image_ids"][i]]
+ # randomly draw one of all available captions per image
+ caption = captions[np.random.randint(0, len(captions))]
+ example = {"image": image,
+ "caption": [str(caption[0])],
+ "segmentation": segmentation,
+ "img_path": img_path,
+ "seg_path": seg_path,
+ "filename_": img_path.split(os.sep)[-1]
+ }
+ return example
+
+
+class CocoImagesAndCaptionsTrain(CocoBase):
+ """returns a pair of (image, caption)"""
+ def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False):
+ super().__init__(size=size,
+ dataroot="data/coco/train2017",
+ datajson="data/coco/annotations/captions_train2017.json",
+ onehot_segmentation=onehot_segmentation,
+ use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop)
+
+ def get_split(self):
+ return "train"
+
+
+class CocoImagesAndCaptionsValidation(CocoBase):
+ """returns a pair of (image, caption)"""
+ def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
+ given_files=None):
+ super().__init__(size=size,
+ dataroot="data/coco/val2017",
+ datajson="data/coco/annotations/captions_val2017.json",
+ onehot_segmentation=onehot_segmentation,
+ use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
+ given_files=given_files)
+
+ def get_split(self):
+ return "validation"
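+
+
+if __name__ == "__main__":
+    # Editor's sketch, not part of the upstream taming code: requires the COCO 2017
+    # captions and images under the data/coco paths hard-coded in CocoBase above.
+    dset = CocoImagesAndCaptionsValidation(size=256)
+    ex = dset[0]
+    print(ex["caption"], ex["image"].shape)  # one random caption, (256, 256, 3)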
diff --git a/StableSR/taming/data/conditional_builder/objects_bbox.py b/StableSR/taming/data/conditional_builder/objects_bbox.py
new file mode 100644
index 0000000000000000000000000000000000000000..15881e76b7ab2a914df8f2dfe08ae4f0c6c511b5
--- /dev/null
+++ b/StableSR/taming/data/conditional_builder/objects_bbox.py
@@ -0,0 +1,60 @@
+from itertools import cycle
+from typing import List, Tuple, Callable, Optional
+
+from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
+from more_itertools.recipes import grouper
+from taming.data.image_transforms import convert_pil_to_tensor
+from torch import LongTensor, Tensor
+
+from taming.data.helper_types import BoundingBox, Annotation
+from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
+from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \
+ pad_list, get_plot_font_size, absolute_bbox
+
+
+class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder):
+ @property
+ def object_descriptor_length(self) -> int:
+ return 3
+
+ def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
+ object_triples = [
+ (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox))
+ for ann in annotations
+ ]
+ empty_triple = (self.none, self.none, self.none)
+ object_triples = pad_list(object_triples, empty_triple, self.no_max_objects)
+ return object_triples
+
+ def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]:
+ conditional_list = conditional.tolist()
+ crop_coordinates = None
+ if self.encode_crop:
+ crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
+ conditional_list = conditional_list[:-2]
+ object_triples = grouper(conditional_list, 3)
+ assert conditional.shape[0] == self.embedding_dim
+ return [
+ (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2]))
+ for object_triple in object_triples if object_triple[0] != self.none
+ ], crop_coordinates
+
+ def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
+ line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
+ plot = pil_image.new('RGB', figure_size, WHITE)
+ draw = pil_img_draw.Draw(plot)
+ font = ImageFont.truetype(
+ "/usr/share/fonts/truetype/lato/Lato-Regular.ttf",
+ size=get_plot_font_size(font_size, figure_size)
+ )
+ width, height = plot.size
+ description, crop_coordinates = self.inverse_build(conditional)
+ for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)):
+ annotation = self.representation_to_annotation(representation)
+ class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation)
+ bbox = absolute_bbox(bbox, width, height)
+ draw.rectangle(bbox, outline=color, width=line_width)
+ draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font)
+ if crop_coordinates is not None:
+ draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
+ return convert_pil_to_tensor(plot) / 127.5 - 1.
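+
+
+if __name__ == "__main__":
+    # Editor's sketch, not part of the upstream taming code: round-trips one bounding
+    # box through build() / inverse_build(). The Annotation values are made up purely
+    # for illustration; with no_tokens=1024 the coordinates snap to a 32x32 grid.
+    builder = ObjectsBoundingBoxConditionalBuilder(
+        no_object_classes=10, no_max_objects=4, no_tokens=1024,
+        encode_crop=False, use_group_parameter=False, use_additional_parameters=False)
+    ann = Annotation(area=0.09, image_id='0', bbox=(0.1, 0.2, 0.3, 0.3),
+                     category_no=5, category_id='5', is_group_of=False)
+    tokens = builder.build([ann])
+    print(builder.inverse_build(tokens))  # [(5, bbox snapped to the grid)], None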
diff --git a/StableSR/taming/data/conditional_builder/objects_center_points.py b/StableSR/taming/data/conditional_builder/objects_center_points.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a480329cc47fb38a7b8729d424e092b77d40749
--- /dev/null
+++ b/StableSR/taming/data/conditional_builder/objects_center_points.py
@@ -0,0 +1,168 @@
+import math
+import random
+import warnings
+from itertools import cycle
+from typing import List, Optional, Tuple, Callable
+
+from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
+from more_itertools.recipes import grouper
+from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \
+ additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \
+ absolute_bbox, rescale_annotations
+from taming.data.helper_types import BoundingBox, Annotation
+from taming.data.image_transforms import convert_pil_to_tensor
+from torch import LongTensor, Tensor
+
+
+class ObjectsCenterPointsConditionalBuilder:
+ def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, encode_crop: bool,
+ use_group_parameter: bool, use_additional_parameters: bool):
+ self.no_object_classes = no_object_classes
+ self.no_max_objects = no_max_objects
+ self.no_tokens = no_tokens
+ self.encode_crop = encode_crop
+ self.no_sections = int(math.sqrt(self.no_tokens))
+ self.use_group_parameter = use_group_parameter
+ self.use_additional_parameters = use_additional_parameters
+
+ @property
+ def none(self) -> int:
+ return self.no_tokens - 1
+
+ @property
+ def object_descriptor_length(self) -> int:
+ return 2
+
+ @property
+ def embedding_dim(self) -> int:
+ extra_length = 2 if self.encode_crop else 0
+ return self.no_max_objects * self.object_descriptor_length + extra_length
+
+ def tokenize_coordinates(self, x: float, y: float) -> int:
+ """
+ Express 2d coordinates with one number.
+ Example: assume self.no_tokens = 16, then no_sections = 4:
+ 0 0 0 0
+ 0 0 # 0
+ 0 0 0 0
+ 0 0 0 x
+ Then the # position corresponds to token 6, the x position to token 15.
+ @param x: float in [0, 1]
+ @param y: float in [0, 1]
+ @return: discrete tokenized coordinate
+ """
+ x_discrete = int(round(x * (self.no_sections - 1)))
+ y_discrete = int(round(y * (self.no_sections - 1)))
+ return y_discrete * self.no_sections + x_discrete
+
+ def coordinates_from_token(self, token: int) -> (float, float):
+ x = token % self.no_sections
+ y = token // self.no_sections
+ return x / (self.no_sections - 1), y / (self.no_sections - 1)
+
+ def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox:
+ x0, y0 = self.coordinates_from_token(token1)
+ x1, y1 = self.coordinates_from_token(token2)
+ return x0, y0, x1 - x0, y1 - y0
+
+ def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple[int, int]:
+ return self.tokenize_coordinates(bbox[0], bbox[1]), \
+ self.tokenize_coordinates(bbox[0] + bbox[2], bbox[1] + bbox[3])
+
+ def inverse_build(self, conditional: LongTensor) \
+ -> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]:
+ conditional_list = conditional.tolist()
+ crop_coordinates = None
+ if self.encode_crop:
+ crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
+ conditional_list = conditional_list[:-2]
+ table_of_content = grouper(conditional_list, self.object_descriptor_length)
+ assert conditional.shape[0] == self.embedding_dim
+ return [
+ (object_tuple[0], self.coordinates_from_token(object_tuple[1]))
+ for object_tuple in table_of_content if object_tuple[0] != self.none
+ ], crop_coordinates
+
+ def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
+ line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
+ plot = pil_image.new('RGB', figure_size, WHITE)
+ draw = pil_img_draw.Draw(plot)
+ circle_size = get_circle_size(figure_size)
+ font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf',
+ size=get_plot_font_size(font_size, figure_size))
+ width, height = plot.size
+ description, crop_coordinates = self.inverse_build(conditional)
+ for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)):
+ x_abs, y_abs = x * width, y * height
+ ann = self.representation_to_annotation(representation)
+ label = label_for_category_no(ann.category_no) + ' ' + additional_parameters_string(ann)
+ ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size]
+ draw.ellipse(ellipse_bbox, fill=color, width=0)
+ draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font)
+ if crop_coordinates is not None:
+ draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
+ return convert_pil_to_tensor(plot) / 127.5 - 1.
+
+ def object_representation(self, annotation: Annotation) -> int:
+ modifier = 0
+ if self.use_group_parameter:
+ modifier |= 1 * (annotation.is_group_of is True)
+ if self.use_additional_parameters:
+ modifier |= 2 * (annotation.is_occluded is True)
+ modifier |= 4 * (annotation.is_depiction is True)
+ modifier |= 8 * (annotation.is_inside is True)
+ return annotation.category_no + self.no_object_classes * modifier
+
+ def representation_to_annotation(self, representation: int) -> Annotation:
+ category_no = representation % self.no_object_classes
+ modifier = representation // self.no_object_classes
+ # noinspection PyTypeChecker
+ return Annotation(
+ area=None, image_id=None, bbox=None, category_id=None, id=None, source=None, confidence=None,
+ category_no=category_no,
+ is_group_of=bool((modifier & 1) * self.use_group_parameter),
+ is_occluded=bool((modifier & 2) * self.use_additional_parameters),
+ is_depiction=bool((modifier & 4) * self.use_additional_parameters),
+ is_inside=bool((modifier & 8) * self.use_additional_parameters)
+ )
+
+ def _crop_encoder(self, crop_coordinates: BoundingBox) -> List[int]:
+ return list(self.token_pair_from_bbox(crop_coordinates))
+
+ def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
+ object_tuples = [
+ (self.object_representation(a),
+ self.tokenize_coordinates(a.bbox[0] + a.bbox[2] / 2, a.bbox[1] + a.bbox[3] / 2))
+ for a in annotations
+ ]
+ empty_tuple = (self.none, self.none)
+ object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects)
+ return object_tuples
+
+ def build(self, annotations: List, crop_coordinates: Optional[BoundingBox] = None, horizontal_flip: bool = False) \
+ -> LongTensor:
+ if len(annotations) == 0:
+ warnings.warn('Did not receive any annotations.')
+ if len(annotations) > self.no_max_objects:
+ warnings.warn('Received more annotations than allowed.')
+ annotations = annotations[:self.no_max_objects]
+
+ if not crop_coordinates:
+ crop_coordinates = FULL_CROP
+
+ random.shuffle(annotations)
+ annotations = filter_annotations(annotations, crop_coordinates)
+ if self.encode_crop:
+ annotations = rescale_annotations(annotations, FULL_CROP, horizontal_flip)
+ if horizontal_flip:
+ crop_coordinates = horizontally_flip_bbox(crop_coordinates)
+ extra = self._crop_encoder(crop_coordinates)
+ else:
+ annotations = rescale_annotations(annotations, crop_coordinates, horizontal_flip)
+ extra = []
+
+ object_tuples = self._make_object_descriptors(annotations)
+ flattened = [token for tuple_ in object_tuples for token in tuple_] + extra
+ assert len(flattened) == self.embedding_dim
+ assert all(0 <= value < self.no_tokens for value in flattened)
+ return LongTensor(flattened)
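+
+
+if __name__ == "__main__":
+    # Editor's sketch, not part of the upstream taming code: a worked example of the
+    # coordinate tokenization. With no_tokens=1024 there are 32x32 sections, so
+    # (x=0.5, y=0.25) maps to x_discrete=16, y_discrete=8, i.e. token 8*32 + 16 = 272.
+    builder = ObjectsCenterPointsConditionalBuilder(
+        no_object_classes=10, no_max_objects=8, no_tokens=1024,
+        encode_crop=False, use_group_parameter=False, use_additional_parameters=False)
+    token = builder.tokenize_coordinates(0.5, 0.25)
+    print(token)                                  # 272
+    print(builder.coordinates_from_token(token))  # roughly (0.516, 0.258)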
diff --git a/StableSR/taming/data/conditional_builder/utils.py b/StableSR/taming/data/conditional_builder/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0ee175f2e05a80dbc71c22acbecb22dddadbb42
--- /dev/null
+++ b/StableSR/taming/data/conditional_builder/utils.py
@@ -0,0 +1,105 @@
+import importlib
+from typing import List, Any, Tuple, Optional
+
+from taming.data.helper_types import BoundingBox, Annotation
+
+# source: seaborn, color palette tab10
+COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188),
+ (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)]
+BLACK = (0, 0, 0)
+GRAY_75 = (63, 63, 63)
+GRAY_50 = (127, 127, 127)
+GRAY_25 = (191, 191, 191)
+WHITE = (255, 255, 255)
+FULL_CROP = (0., 0., 1., 1.)
+
+
+def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float:
+ """
+ Give intersection area of two rectangles.
+ @param rectangle1: (x0, y0, w, h) of first rectangle
+ @param rectangle2: (x0, y0, w, h) of second rectangle
+ """
+ rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3]
+ rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3]
+ x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0]))
+ y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1]))
+ return x_overlap * y_overlap
+
+
+def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox:
+ return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3]
+
+
+def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]:
+ bbox = relative_bbox
+ bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height
+ return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
+
+
+def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List:
+ return list_ + [pad_element for _ in range(pad_to_length - len(list_))]
+
+
+def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \
+ List[Annotation]:
+ def clamp(x: float):
+ return max(min(x, 1.), 0.)
+
+ def rescale_bbox(bbox: BoundingBox) -> BoundingBox:
+ x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
+ y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
+ w = min(bbox[2] / crop_coordinates[2], 1 - x0)
+ h = min(bbox[3] / crop_coordinates[3], 1 - y0)
+ if flip:
+ x0 = 1 - (x0 + w)
+ return x0, y0, w, h
+
+ return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations]
+
+
+def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List:
+ return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0]
+
+
+def additional_parameters_string(annotation: Annotation, short: bool = True) -> str:
+ sl = slice(1) if short else slice(None)
+ string = ''
+ if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside):
+ return string
+ if annotation.is_group_of:
+ string += 'group'[sl] + ','
+ if annotation.is_occluded:
+ string += 'occluded'[sl] + ','
+ if annotation.is_depiction:
+ string += 'depiction'[sl] + ','
+ if annotation.is_inside:
+ string += 'inside'[sl]
+ return '(' + string.strip(",") + ')'
+
+
+def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int:
+ if font_size is None:
+ font_size = 10
+ if max(figure_size) >= 256:
+ font_size = 12
+ if max(figure_size) >= 512:
+ font_size = 15
+ return font_size
+
+
+def get_circle_size(figure_size: Tuple[int, int]) -> int:
+ circle_size = 2
+ if max(figure_size) >= 256:
+ circle_size = 3
+ if max(figure_size) >= 512:
+ circle_size = 4
+ return circle_size
+
+
+def load_object_from_string(object_string: str) -> Any:
+ """
+ Source: https://stackoverflow.com/a/10773699
+ """
+ module_name, class_name = object_string.rsplit(".", 1)
+ return getattr(importlib.import_module(module_name), class_name)
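+
+
+if __name__ == "__main__":
+    # Editor's sketch, not part of the upstream taming code: boxes are relative
+    # (x0, y0, w, h) tuples. Two half-size boxes offset by a quarter of the image
+    # overlap in a 0.25 x 0.25 region, and flipping mirrors x0 around the centre.
+    print(intersection_area((0.0, 0.0, 0.5, 0.5), (0.25, 0.25, 0.5, 0.5)))  # 0.0625
+    print(horizontally_flip_bbox((0.125, 0.25, 0.25, 0.5)))  # (0.625, 0.25, 0.25, 0.5)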
diff --git a/StableSR/taming/data/custom.py b/StableSR/taming/data/custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..33f302a4b55ba1e8ec282ec3292b6263c06dfb91
--- /dev/null
+++ b/StableSR/taming/data/custom.py
@@ -0,0 +1,38 @@
+import os
+import numpy as np
+import albumentations
+from torch.utils.data import Dataset
+
+from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
+
+
+class CustomBase(Dataset):
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ self.data = None
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ example = self.data[i]
+ return example
+
+
+
+class CustomTrain(CustomBase):
+ def __init__(self, size, training_images_list_file):
+ super().__init__()
+ with open(training_images_list_file, "r") as f:
+ paths = f.read().splitlines()
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
+
+
+class CustomTest(CustomBase):
+ def __init__(self, size, test_images_list_file):
+ super().__init__()
+ with open(test_images_list_file, "r") as f:
+ paths = f.read().splitlines()
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
+
+
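+if __name__ == "__main__":
+    # Editor's sketch, not part of the upstream taming code: CustomTrain/CustomTest
+    # read a plain text file with one image path per line; the file name below is an
+    # assumption.
+    dset = CustomTrain(size=256, training_images_list_file="data/custom_train.txt")
+    print(len(dset))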
diff --git a/StableSR/taming/data/faceshq.py b/StableSR/taming/data/faceshq.py
new file mode 100644
index 0000000000000000000000000000000000000000..6912d04b66a6d464c1078e4b51d5da290f5e767e
--- /dev/null
+++ b/StableSR/taming/data/faceshq.py
@@ -0,0 +1,134 @@
+import os
+import numpy as np
+import albumentations
+from torch.utils.data import Dataset
+
+from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
+
+
+class FacesBase(Dataset):
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ self.data = None
+ self.keys = None
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ example = self.data[i]
+ ex = {}
+ if self.keys is not None:
+ for k in self.keys:
+ ex[k] = example[k]
+ else:
+ ex = example
+ return ex
+
+
+class CelebAHQTrain(FacesBase):
+ def __init__(self, size, keys=None):
+ super().__init__()
+ root = "data/celebahq"
+ with open("data/celebahqtrain.txt", "r") as f:
+ relpaths = f.read().splitlines()
+ paths = [os.path.join(root, relpath) for relpath in relpaths]
+ self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
+ self.keys = keys
+
+
+class CelebAHQValidation(FacesBase):
+ def __init__(self, size, keys=None):
+ super().__init__()
+ root = "data/celebahq"
+ with open("data/celebahqvalidation.txt", "r") as f:
+ relpaths = f.read().splitlines()
+ paths = [os.path.join(root, relpath) for relpath in relpaths]
+ self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
+ self.keys = keys
+
+
+class FFHQTrain(FacesBase):
+ def __init__(self, size, keys=None):
+ super().__init__()
+ root = "data/ffhq"
+ with open("data/ffhqtrain.txt", "r") as f:
+ relpaths = f.read().splitlines()
+ paths = [os.path.join(root, relpath) for relpath in relpaths]
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
+ self.keys = keys
+
+
+class FFHQValidation(FacesBase):
+ def __init__(self, size, keys=None):
+ super().__init__()
+ root = "data/ffhq"
+ with open("data/ffhqvalidation.txt", "r") as f:
+ relpaths = f.read().splitlines()
+ paths = [os.path.join(root, relpath) for relpath in relpaths]
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
+ self.keys = keys
+
+
+class FacesHQTrain(Dataset):
+ # CelebAHQ [0] + FFHQ [1]
+ def __init__(self, size, keys=None, crop_size=None, coord=False):
+ d1 = CelebAHQTrain(size=size, keys=keys)
+ d2 = FFHQTrain(size=size, keys=keys)
+ self.data = ConcatDatasetWithIndex([d1, d2])
+ self.coord = coord
+ if crop_size is not None:
+ self.cropper = albumentations.RandomCrop(height=crop_size,width=crop_size)
+ if self.coord:
+ self.cropper = albumentations.Compose([self.cropper],
+ additional_targets={"coord": "image"})
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ ex, y = self.data[i]
+ if hasattr(self, "cropper"):
+ if not self.coord:
+ out = self.cropper(image=ex["image"])
+ ex["image"] = out["image"]
+ else:
+ h,w,_ = ex["image"].shape
+ coord = np.arange(h*w).reshape(h,w,1)/(h*w)
+ out = self.cropper(image=ex["image"], coord=coord)
+ ex["image"] = out["image"]
+ ex["coord"] = out["coord"]
+ ex["class"] = y
+ return ex
+
+
+class FacesHQValidation(Dataset):
+ # CelebAHQ [0] + FFHQ [1]
+ def __init__(self, size, keys=None, crop_size=None, coord=False):
+ d1 = CelebAHQValidation(size=size, keys=keys)
+ d2 = FFHQValidation(size=size, keys=keys)
+ self.data = ConcatDatasetWithIndex([d1, d2])
+ self.coord = coord
+ if crop_size is not None:
+ self.cropper = albumentations.CenterCrop(height=crop_size,width=crop_size)
+ if self.coord:
+ self.cropper = albumentations.Compose([self.cropper],
+ additional_targets={"coord": "image"})
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ ex, y = self.data[i]
+ if hasattr(self, "cropper"):
+ if not self.coord:
+ out = self.cropper(image=ex["image"])
+ ex["image"] = out["image"]
+ else:
+ h,w,_ = ex["image"].shape
+ coord = np.arange(h*w).reshape(h,w,1)/(h*w)
+ out = self.cropper(image=ex["image"], coord=coord)
+ ex["image"] = out["image"]
+ ex["coord"] = out["coord"]
+ ex["class"] = y
+ return ex
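+
+
+if __name__ == "__main__":
+    # Editor's sketch, not part of the upstream taming code: FacesHQ concatenates
+    # CelebA-HQ (class 0) and FFHQ (class 1); with coord=True the crop also returns a
+    # per-pixel coordinate channel. Requires the data/ file lists referenced above.
+    dset = FacesHQValidation(size=256, crop_size=128, coord=True)
+    ex = dset[0]
+    print(ex["image"].shape, ex["coord"].shape, ex["class"])  # (128, 128, 3) (128, 128, 1) 0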
diff --git a/StableSR/taming/data/helper_types.py b/StableSR/taming/data/helper_types.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb51e301da08602cfead5961c4f7e1d89f6aba79
--- /dev/null
+++ b/StableSR/taming/data/helper_types.py
@@ -0,0 +1,49 @@
+from typing import Dict, Tuple, Optional, NamedTuple, Union
+from PIL.Image import Image as pil_image
+from torch import Tensor
+
+try:
+ from typing import Literal
+except ImportError:
+ from typing_extensions import Literal
+
+Image = Union[Tensor, pil_image]
+BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h
+CropMethodType = Literal['none', 'random-1d', 'random-2d', 'center']
+SplitType = Literal['train', 'validation', 'test']
+
+
+class ImageDescription(NamedTuple):
+ id: int
+ file_name: str
+ original_size: Tuple[int, int] # w, h
+ url: Optional[str] = None
+ license: Optional[int] = None
+ coco_url: Optional[str] = None
+ date_captured: Optional[str] = None
+ flickr_url: Optional[str] = None
+ flickr_id: Optional[str] = None
+ coco_id: Optional[str] = None
+
+
+class Category(NamedTuple):
+ id: str
+ super_category: Optional[str]
+ name: str
+
+
+class Annotation(NamedTuple):
+ area: float
+ image_id: str
+ bbox: BoundingBox
+ category_no: int
+ category_id: str
+ id: Optional[int] = None
+ source: Optional[str] = None
+ confidence: Optional[float] = None
+ is_group_of: Optional[bool] = None
+ is_truncated: Optional[bool] = None
+ is_occluded: Optional[bool] = None
+ is_depiction: Optional[bool] = None
+ is_inside: Optional[bool] = None
+ segmentation: Optional[Dict] = None
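+
+
+if __name__ == "__main__":
+    # Editor's sketch, not part of the upstream taming code: bounding boxes throughout
+    # these datasets are relative (x0, y0, w, h) tuples, and Annotation is a NamedTuple,
+    # so unknown fields can stay at their defaults and be swapped later via _replace().
+    box: BoundingBox = (0.1, 0.2, 0.3, 0.4)
+    ann = Annotation(area=0.12, image_id='img0', bbox=box, category_no=3, category_id='7')
+    print(ann._replace(bbox=(0.0, 0.0, 1.0, 1.0)))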
diff --git a/StableSR/taming/data/image_transforms.py b/StableSR/taming/data/image_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..657ac332174e0ac72f68315271ffbd757b771a0f
--- /dev/null
+++ b/StableSR/taming/data/image_transforms.py
@@ -0,0 +1,132 @@
+import random
+import warnings
+from typing import Union
+
+import torch
+from torch import Tensor
+from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor
+try:
+    from torchvision.transforms.functional import get_image_size  # torchvision >= 0.11
+except ImportError:  # older torchvision only exposes the private helper
+    from torchvision.transforms.functional import _get_image_size as get_image_size
+
+from taming.data.helper_types import BoundingBox, Image
+
+pil_to_tensor = PILToTensor()
+
+
+def convert_pil_to_tensor(image: Image) -> Tensor:
+ with warnings.catch_warnings():
+ # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194
+ warnings.simplefilter("ignore")
+ return pil_to_tensor(image)
+
+
+class RandomCrop1dReturnCoordinates(RandomCrop):
+ def forward(self, img: Image) -> (BoundingBox, Image):
+ """
+ Additionally to cropping, returns the relative coordinates of the crop bounding box.
+ Args:
+ img (PIL Image or Tensor): Image to be cropped.
+
+ Returns:
+ Bounding box: x0, y0, w, h
+ PIL Image or Tensor: Cropped image.
+
+ Based on:
+ torchvision.transforms.RandomCrop, torchvision 1.7.0
+ """
+ if self.padding is not None:
+ img = F.pad(img, self.padding, self.fill, self.padding_mode)
+
+ width, height = get_image_size(img)
+ # pad the width if needed
+ if self.pad_if_needed and width < self.size[1]:
+ padding = [self.size[1] - width, 0]
+ img = F.pad(img, padding, self.fill, self.padding_mode)
+ # pad the height if needed
+ if self.pad_if_needed and height < self.size[0]:
+ padding = [0, self.size[0] - height]
+ img = F.pad(img, padding, self.fill, self.padding_mode)
+
+ i, j, h, w = self.get_params(img, self.size)
+ bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h
+ return bbox, F.crop(img, i, j, h, w)
+
+
+class Random2dCropReturnCoordinates(torch.nn.Module):
+ """
+ Additionally to cropping, returns the relative coordinates of the crop bounding box.
+ Args:
+ img (PIL Image or Tensor): Image to be cropped.
+
+ Returns:
+ Bounding box: x0, y0, w, h
+ PIL Image or Tensor: Cropped image.
+
+ Based on:
+ torchvision.transforms.RandomCrop, torchvision 1.7.0
+ """
+
+ def __init__(self, min_size: int):
+ super().__init__()
+ self.min_size = min_size
+
+ def forward(self, img: Image) -> (BoundingBox, Image):
+ width, height = get_image_size(img)
+ max_size = min(width, height)
+ if max_size <= self.min_size:
+ size = max_size
+ else:
+ size = random.randint(self.min_size, max_size)
+ top = random.randint(0, height - size)
+ left = random.randint(0, width - size)
+ bbox = left / width, top / height, size / width, size / height
+ return bbox, F.crop(img, top, left, size, size)
+
+
+class CenterCropReturnCoordinates(CenterCrop):
+ @staticmethod
+ def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox:
+ if width > height:
+ w = height / width
+ h = 1.0
+ x0 = 0.5 - w / 2
+ y0 = 0.
+ else:
+ w = 1.0
+ h = width / height
+ x0 = 0.
+ y0 = 0.5 - h / 2
+ return x0, y0, w, h
+
+ def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]):
+ """
+ Additionally to cropping, returns the relative coordinates of the crop bounding box.
+ Args:
+ img (PIL Image or Tensor): Image to be cropped.
+
+ Returns:
+ Bounding box: x0, y0, w, h
+ PIL Image or Tensor: Cropped image.
+ Based on:
+ torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
+ """
+ width, height = get_image_size(img)
+ return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size)
+
+
+class RandomHorizontalFlipReturn(RandomHorizontalFlip):
+ def forward(self, img: Image) -> (bool, Image):
+ """
+ Additionally to flipping, returns a boolean whether it was flipped or not.
+ Args:
+ img (PIL Image or Tensor): Image to be flipped.
+
+ Returns:
+ flipped: whether the image was flipped or not
+ PIL Image or Tensor: Randomly flipped image.
+
+ Based on:
+ torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
+ """
+ if torch.rand(1) < self.p:
+ return True, F.hflip(img)
+ return False, img
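+
+
+if __name__ == "__main__":
+    # Editor's sketch, not part of the upstream taming code: the *ReturnCoordinates
+    # transforms report where the crop was taken as a relative (x0, y0, w, h) box.
+    # For a 640x480 image, a centre crop keeps the full height and 480/640 of the
+    # width, starting at x0 = 0.125.
+    print(CenterCropReturnCoordinates.get_bbox_of_center_crop(width=640, height=480))
+    # -> (0.125, 0.0, 0.75, 1.0)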
diff --git a/StableSR/taming/data/imagenet.py b/StableSR/taming/data/imagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a02ec44ba4af9e993f58c91fa43482a4ecbe54c
--- /dev/null
+++ b/StableSR/taming/data/imagenet.py
@@ -0,0 +1,558 @@
+import os, tarfile, glob, shutil
+import yaml
+import numpy as np
+from tqdm import tqdm
+from PIL import Image
+import albumentations
+from omegaconf import OmegaConf
+from torch.utils.data import Dataset
+
+from taming.data.base import ImagePaths
+from taming.util import download, retrieve
+import taming.data.utils as bdu
+
+
+def give_synsets_from_indices(indices, path_to_yaml="data/imagenet_idx_to_synset.yaml"):
+ synsets = []
+ with open(path_to_yaml) as f:
+        di2s = yaml.safe_load(f)
+ for idx in indices:
+ synsets.append(str(di2s[idx]))
+    print("Using {} different synsets for construction of Restricted ImageNet.".format(len(synsets)))
+ return synsets
+
+
+def str_to_indices(string):
+ """Expects a string in the format '32-123, 256, 280-321'"""
+ assert not string.endswith(","), "provided string '{}' ends with a comma, pls remove it".format(string)
+ subs = string.split(",")
+ indices = []
+ for sub in subs:
+ subsubs = sub.split("-")
+ assert len(subsubs) > 0
+ if len(subsubs) == 1:
+ indices.append(int(subsubs[0]))
+ else:
+ rang = [j for j in range(int(subsubs[0]), int(subsubs[1]))]
+ indices.extend(rang)
+ return sorted(indices)
+
+
+class ImageNetBase(Dataset):
+ def __init__(self, config=None):
+ self.config = config or OmegaConf.create()
+        if not isinstance(self.config, dict):
+ self.config = OmegaConf.to_container(self.config)
+ self._prepare()
+ self._prepare_synset_to_human()
+ self._prepare_idx_to_synset()
+ self._load()
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ return self.data[i]
+
+ def _prepare(self):
+ raise NotImplementedError()
+
+ def _filter_relpaths(self, relpaths):
+ ignore = set([
+ "n06596364_9591.JPEG",
+ ])
+ relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
+ if "sub_indices" in self.config:
+ indices = str_to_indices(self.config["sub_indices"])
+ synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
+ files = []
+ for rpath in relpaths:
+ syn = rpath.split("/")[0]
+ if syn in synsets:
+ files.append(rpath)
+ return files
+ else:
+ return relpaths
+
+ def _prepare_synset_to_human(self):
+ SIZE = 2655750
+ URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
+ self.human_dict = os.path.join(self.root, "synset_human.txt")
+ if (not os.path.exists(self.human_dict) or
+ not os.path.getsize(self.human_dict)==SIZE):
+ download(URL, self.human_dict)
+
+ def _prepare_idx_to_synset(self):
+ URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
+ self.idx2syn = os.path.join(self.root, "index_synset.yaml")
+ if (not os.path.exists(self.idx2syn)):
+ download(URL, self.idx2syn)
+
+ def _load(self):
+ with open(self.txt_filelist, "r") as f:
+ self.relpaths = f.read().splitlines()
+ l1 = len(self.relpaths)
+ self.relpaths = self._filter_relpaths(self.relpaths)
+ print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
+
+ self.synsets = [p.split("/")[0] for p in self.relpaths]
+ self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
+
+ unique_synsets = np.unique(self.synsets)
+ class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
+ self.class_labels = [class_dict[s] for s in self.synsets]
+
+ with open(self.human_dict, "r") as f:
+ human_dict = f.read().splitlines()
+ human_dict = dict(line.split(maxsplit=1) for line in human_dict)
+
+ self.human_labels = [human_dict[s] for s in self.synsets]
+
+ labels = {
+ "relpath": np.array(self.relpaths),
+ "synsets": np.array(self.synsets),
+ "class_label": np.array(self.class_labels),
+ "human_label": np.array(self.human_labels),
+ }
+ self.data = ImagePaths(self.abspaths,
+ labels=labels,
+ size=retrieve(self.config, "size", default=0),
+ random_crop=self.random_crop)
+
+
+class ImageNetTrain(ImageNetBase):
+ NAME = "ILSVRC2012_train"
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
+ AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
+ FILES = [
+ "ILSVRC2012_img_train.tar",
+ ]
+ SIZES = [
+ 147897477120,
+ ]
+
+ def _prepare(self):
+ self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
+ default=True)
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
+ self.datadir = os.path.join(self.root, "data")
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
+ self.expected_length = 1281167
+ if not bdu.is_prepared(self.root):
+ # prep
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
+
+ datadir = self.datadir
+ if not os.path.exists(datadir):
+ path = os.path.join(self.root, self.FILES[0])
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
+ import academictorrents as at
+ atpath = at.get(self.AT_HASH, datastore=self.root)
+ assert atpath == path
+
+ print("Extracting {} to {}".format(path, datadir))
+ os.makedirs(datadir, exist_ok=True)
+ with tarfile.open(path, "r:") as tar:
+ tar.extractall(path=datadir)
+
+ print("Extracting sub-tars.")
+ subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
+ for subpath in tqdm(subpaths):
+ subdir = subpath[:-len(".tar")]
+ os.makedirs(subdir, exist_ok=True)
+ with tarfile.open(subpath, "r:") as tar:
+ tar.extractall(path=subdir)
+
+
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
+ filelist = sorted(filelist)
+ filelist = "\n".join(filelist)+"\n"
+ with open(self.txt_filelist, "w") as f:
+ f.write(filelist)
+
+ bdu.mark_prepared(self.root)
+
+
+class ImageNetValidation(ImageNetBase):
+ NAME = "ILSVRC2012_validation"
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
+ AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
+ VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
+ FILES = [
+ "ILSVRC2012_img_val.tar",
+ "validation_synset.txt",
+ ]
+ SIZES = [
+ 6744924160,
+ 1950000,
+ ]
+
+ def _prepare(self):
+ self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
+ default=False)
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
+ self.datadir = os.path.join(self.root, "data")
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
+ self.expected_length = 50000
+ if not bdu.is_prepared(self.root):
+ # prep
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
+
+ datadir = self.datadir
+ if not os.path.exists(datadir):
+ path = os.path.join(self.root, self.FILES[0])
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
+ import academictorrents as at
+ atpath = at.get(self.AT_HASH, datastore=self.root)
+ assert atpath == path
+
+ print("Extracting {} to {}".format(path, datadir))
+ os.makedirs(datadir, exist_ok=True)
+ with tarfile.open(path, "r:") as tar:
+ tar.extractall(path=datadir)
+
+ vspath = os.path.join(self.root, self.FILES[1])
+ if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
+ download(self.VS_URL, vspath)
+
+ with open(vspath, "r") as f:
+ synset_dict = f.read().splitlines()
+ synset_dict = dict(line.split() for line in synset_dict)
+
+ print("Reorganizing into synset folders")
+ synsets = np.unique(list(synset_dict.values()))
+ for s in synsets:
+ os.makedirs(os.path.join(datadir, s), exist_ok=True)
+ for k, v in synset_dict.items():
+ src = os.path.join(datadir, k)
+ dst = os.path.join(datadir, v)
+ shutil.move(src, dst)
+
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
+ filelist = sorted(filelist)
+ filelist = "\n".join(filelist)+"\n"
+ with open(self.txt_filelist, "w") as f:
+ f.write(filelist)
+
+ bdu.mark_prepared(self.root)
+
+
+def get_preprocessor(size=None, random_crop=False, additional_targets=None,
+ crop_size=None):
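+    # size > 0: rescale the shortest side to `size`, center- or random-crop to size x size,
+    # and apply a random horizontal flip; crop_size > 0 (without size): crop only; otherwise identity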
+ if size is not None and size > 0:
+ transforms = list()
+ rescaler = albumentations.SmallestMaxSize(max_size = size)
+ transforms.append(rescaler)
+ if not random_crop:
+ cropper = albumentations.CenterCrop(height=size,width=size)
+ transforms.append(cropper)
+ else:
+ cropper = albumentations.RandomCrop(height=size,width=size)
+ transforms.append(cropper)
+ flipper = albumentations.HorizontalFlip()
+ transforms.append(flipper)
+ preprocessor = albumentations.Compose(transforms,
+ additional_targets=additional_targets)
+ elif crop_size is not None and crop_size > 0:
+ if not random_crop:
+ cropper = albumentations.CenterCrop(height=crop_size,width=crop_size)
+ else:
+ cropper = albumentations.RandomCrop(height=crop_size,width=crop_size)
+ transforms = [cropper]
+ preprocessor = albumentations.Compose(transforms,
+ additional_targets=additional_targets)
+ else:
+ preprocessor = lambda **kwargs: kwargs
+ return preprocessor
+
+
+def rgba_to_depth(x):
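+    # reinterpret the four uint8 RGBA channels of each pixel as a single float32 depth value
+    # (an in-memory buffer reinterpretation via the dtype assignment, not a value cast)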
+ assert x.dtype == np.uint8
+ assert len(x.shape) == 3 and x.shape[2] == 4
+ y = x.copy()
+ y.dtype = np.float32
+ y = y.reshape(x.shape[:2])
+ return np.ascontiguousarray(y)
+
+
+class BaseWithDepth(Dataset):
+ DEFAULT_DEPTH_ROOT="data/imagenet_depth"
+
+ def __init__(self, config=None, size=None, random_crop=False,
+ crop_size=None, root=None):
+ self.config = config
+ self.base_dset = self.get_base_dset()
+ self.preprocessor = get_preprocessor(
+ size=size,
+ crop_size=crop_size,
+ random_crop=random_crop,
+ additional_targets={"depth": "image"})
+ self.crop_size = crop_size
+ if self.crop_size is not None:
+ self.rescaler = albumentations.Compose(
+ [albumentations.SmallestMaxSize(max_size = self.crop_size)],
+ additional_targets={"depth": "image"})
+ if root is not None:
+ self.DEFAULT_DEPTH_ROOT = root
+
+ def __len__(self):
+ return len(self.base_dset)
+
+ def preprocess_depth(self, path):
+ rgba = np.array(Image.open(path))
+ depth = rgba_to_depth(rgba)
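+        # min-max normalize to [0, 1] (the epsilon guards against constant depth maps), then map to [-1, 1]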
+ depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min())
+ depth = 2.0*depth-1.0
+ return depth
+
+ def __getitem__(self, i):
+ e = self.base_dset[i]
+ e["depth"] = self.preprocess_depth(self.get_depth_path(e))
+ # up if necessary
+ h,w,c = e["image"].shape
+ if self.crop_size and min(h,w) < self.crop_size:
+ # have to upscale to be able to crop - this just uses bilinear
+ out = self.rescaler(image=e["image"], depth=e["depth"])
+ e["image"] = out["image"]
+ e["depth"] = out["depth"]
+ transformed = self.preprocessor(image=e["image"], depth=e["depth"])
+ e["image"] = transformed["image"]
+ e["depth"] = transformed["depth"]
+ return e
+
+
+class ImageNetTrainWithDepth(BaseWithDepth):
+ # default to random_crop=True
+ def __init__(self, random_crop=True, sub_indices=None, **kwargs):
+ self.sub_indices = sub_indices
+ super().__init__(random_crop=random_crop, **kwargs)
+
+ def get_base_dset(self):
+ if self.sub_indices is None:
+ return ImageNetTrain()
+ else:
+ return ImageNetTrain({"sub_indices": self.sub_indices})
+
+ def get_depth_path(self, e):
+ fid = os.path.splitext(e["relpath"])[0]+".png"
+ fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "train", fid)
+ return fid
+
+
+class ImageNetValidationWithDepth(BaseWithDepth):
+ def __init__(self, sub_indices=None, **kwargs):
+ self.sub_indices = sub_indices
+ super().__init__(**kwargs)
+
+ def get_base_dset(self):
+ if self.sub_indices is None:
+ return ImageNetValidation()
+ else:
+ return ImageNetValidation({"sub_indices": self.sub_indices})
+
+ def get_depth_path(self, e):
+ fid = os.path.splitext(e["relpath"])[0]+".png"
+ fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "val", fid)
+ return fid
+
+
+class RINTrainWithDepth(ImageNetTrainWithDepth):
+ def __init__(self, config=None, size=None, random_crop=True, crop_size=None):
+ sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
+ super().__init__(config=config, size=size, random_crop=random_crop,
+ sub_indices=sub_indices, crop_size=crop_size)
+
+
+class RINValidationWithDepth(ImageNetValidationWithDepth):
+ def __init__(self, config=None, size=None, random_crop=False, crop_size=None):
+ sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
+ super().__init__(config=config, size=size, random_crop=random_crop,
+ sub_indices=sub_indices, crop_size=crop_size)
+
+
+class DRINExamples(Dataset):
+ def __init__(self):
+ self.preprocessor = get_preprocessor(size=256, additional_targets={"depth": "image"})
+ with open("data/drin_examples.txt", "r") as f:
+ relpaths = f.read().splitlines()
+ self.image_paths = [os.path.join("data/drin_images",
+ relpath) for relpath in relpaths]
+ self.depth_paths = [os.path.join("data/drin_depth",
+ relpath.replace(".JPEG", ".png")) for relpath in relpaths]
+
+ def __len__(self):
+ return len(self.image_paths)
+
+ def preprocess_image(self, image_path):
+ image = Image.open(image_path)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+ image = self.preprocessor(image=image)["image"]
+ image = (image/127.5 - 1.0).astype(np.float32)
+ return image
+
+ def preprocess_depth(self, path):
+ rgba = np.array(Image.open(path))
+ depth = rgba_to_depth(rgba)
+ depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min())
+ depth = 2.0*depth-1.0
+ return depth
+
+ def __getitem__(self, i):
+ e = dict()
+ e["image"] = self.preprocess_image(self.image_paths[i])
+ e["depth"] = self.preprocess_depth(self.depth_paths[i])
+ transformed = self.preprocessor(image=e["image"], depth=e["depth"])
+ e["image"] = transformed["image"]
+ e["depth"] = transformed["depth"]
+ return e
+
+
+def imscale(x, factor, keepshapes=False, keepmode="bicubic"):
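+    # downscale an HWC float image in [-1, 1] by an integer factor (bicubic, via PIL);
+    # with keepshapes=True the result is resized back to the original resolution using keepmode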
+ if factor is None or factor==1:
+ return x
+
+ dtype = x.dtype
+ assert dtype in [np.float32, np.float64]
+ assert x.min() >= -1
+ assert x.max() <= 1
+
+ keepmode = {"nearest": Image.NEAREST, "bilinear": Image.BILINEAR,
+ "bicubic": Image.BICUBIC}[keepmode]
+
+ lr = (x+1.0)*127.5
+ lr = lr.clip(0,255).astype(np.uint8)
+ lr = Image.fromarray(lr)
+
+ h, w, _ = x.shape
+ nh = h//factor
+ nw = w//factor
+ assert nh > 0 and nw > 0, (nh, nw)
+
+ lr = lr.resize((nw,nh), Image.BICUBIC)
+ if keepshapes:
+ lr = lr.resize((w,h), keepmode)
+ lr = np.array(lr)/127.5-1.0
+ lr = lr.astype(dtype)
+
+ return lr
+
+
+class ImageNetScale(Dataset):
+ def __init__(self, size=None, crop_size=None, random_crop=False,
+ up_factor=None, hr_factor=None, keep_mode="bicubic"):
+ self.base = self.get_base()
+
+ self.size = size
+ self.crop_size = crop_size if crop_size is not None else self.size
+ self.random_crop = random_crop
+ self.up_factor = up_factor
+ self.hr_factor = hr_factor
+ self.keep_mode = keep_mode
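+        # hr_factor downscales the base image itself; up_factor additionally produces a degraded
+        # copy (downscaled, then resized back to full size with keep_mode) returned under the "lr" key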
+
+ transforms = list()
+
+ if self.size is not None and self.size > 0:
+ rescaler = albumentations.SmallestMaxSize(max_size = self.size)
+ self.rescaler = rescaler
+ transforms.append(rescaler)
+
+ if self.crop_size is not None and self.crop_size > 0:
+ if len(transforms) == 0:
+ self.rescaler = albumentations.SmallestMaxSize(max_size = self.crop_size)
+
+ if not self.random_crop:
+ cropper = albumentations.CenterCrop(height=self.crop_size,width=self.crop_size)
+ else:
+ cropper = albumentations.RandomCrop(height=self.crop_size,width=self.crop_size)
+ transforms.append(cropper)
+
+ if len(transforms) > 0:
+ if self.up_factor is not None:
+ additional_targets = {"lr": "image"}
+ else:
+ additional_targets = None
+ self.preprocessor = albumentations.Compose(transforms,
+ additional_targets=additional_targets)
+ else:
+ self.preprocessor = lambda **kwargs: kwargs
+
+ def __len__(self):
+ return len(self.base)
+
+ def __getitem__(self, i):
+ example = self.base[i]
+ image = example["image"]
+ # adjust resolution
+ image = imscale(image, self.hr_factor, keepshapes=False)
+ h,w,c = image.shape
+ if self.crop_size and min(h,w) < self.crop_size:
+ # have to upscale to be able to crop - this just uses bilinear
+ image = self.rescaler(image=image)["image"]
+ if self.up_factor is None:
+ image = self.preprocessor(image=image)["image"]
+ example["image"] = image
+ else:
+ lr = imscale(image, self.up_factor, keepshapes=True,
+ keepmode=self.keep_mode)
+
+ out = self.preprocessor(image=image, lr=lr)
+ example["image"] = out["image"]
+ example["lr"] = out["lr"]
+
+ return example
+
+class ImageNetScaleTrain(ImageNetScale):
+ def __init__(self, random_crop=True, **kwargs):
+ super().__init__(random_crop=random_crop, **kwargs)
+
+ def get_base(self):
+ return ImageNetTrain()
+
+class ImageNetScaleValidation(ImageNetScale):
+ def get_base(self):
+ return ImageNetValidation()
+
+
+from skimage.feature import canny
+from skimage.color import rgb2gray
+
+
+class ImageNetEdges(ImageNetScale):
+ def __init__(self, up_factor=1, **kwargs):
+ super().__init__(up_factor=1, **kwargs)
+
+ def __getitem__(self, i):
+ example = self.base[i]
+ image = example["image"]
+ h,w,c = image.shape
+ if self.crop_size and min(h,w) < self.crop_size:
+ # have to upscale to be able to crop - this just uses bilinear
+ image = self.rescaler(image=image)["image"]
+
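+        # the conditioning ("lr") is a Canny edge map of the image, replicated to three channels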
+ lr = canny(rgb2gray(image), sigma=2)
+ lr = lr.astype(np.float32)
+ lr = lr[:,:,None][:,:,[0,0,0]]
+
+ out = self.preprocessor(image=image, lr=lr)
+ example["image"] = out["image"]
+ example["lr"] = out["lr"]
+
+ return example
+
+
+class ImageNetEdgesTrain(ImageNetEdges):
+ def __init__(self, random_crop=True, **kwargs):
+ super().__init__(random_crop=random_crop, **kwargs)
+
+ def get_base(self):
+ return ImageNetTrain()
+
+class ImageNetEdgesValidation(ImageNetEdges):
+ def get_base(self):
+ return ImageNetValidation()
diff --git a/StableSR/taming/data/open_images_helper.py b/StableSR/taming/data/open_images_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..8feb7c6e705fc165d2983303192aaa88f579b243
--- /dev/null
+++ b/StableSR/taming/data/open_images_helper.py
@@ -0,0 +1,379 @@
+open_images_unify_categories_for_coco = {
+ '/m/03bt1vf': '/m/01g317',
+ '/m/04yx4': '/m/01g317',
+ '/m/05r655': '/m/01g317',
+ '/m/01bl7v': '/m/01g317',
+ '/m/0cnyhnx': '/m/01xq0k1',
+ '/m/01226z': '/m/018xm',
+ '/m/05ctyq': '/m/018xm',
+ '/m/058qzx': '/m/04ctx',
+ '/m/06pcq': '/m/0l515',
+ '/m/03m3pdh': '/m/02crq1',
+ '/m/046dlr': '/m/01x3z',
+ '/m/0h8mzrc': '/m/01x3z',
+}
+
+
+top_300_classes_plus_coco_compatibility = [
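+    # each entry is (category name, count) - presumably the number of annotated instances in Open Images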
+ ('Man', 1060962),
+ ('Clothing', 986610),
+ ('Tree', 748162),
+ ('Woman', 611896),
+ ('Person', 610294),
+ ('Human face', 442948),
+ ('Girl', 175399),
+ ('Building', 162147),
+ ('Car', 159135),
+ ('Plant', 155704),
+ ('Human body', 137073),
+ ('Flower', 133128),
+ ('Window', 127485),
+ ('Human arm', 118380),
+ ('House', 114365),
+ ('Wheel', 111684),
+ ('Suit', 99054),
+ ('Human hair', 98089),
+ ('Human head', 92763),
+ ('Chair', 88624),
+ ('Boy', 79849),
+ ('Table', 73699),
+ ('Jeans', 57200),
+ ('Tire', 55725),
+ ('Skyscraper', 53321),
+ ('Food', 52400),
+ ('Footwear', 50335),
+ ('Dress', 50236),
+ ('Human leg', 47124),
+ ('Toy', 46636),
+ ('Tower', 45605),
+ ('Boat', 43486),
+ ('Land vehicle', 40541),
+ ('Bicycle wheel', 34646),
+ ('Palm tree', 33729),
+ ('Fashion accessory', 32914),
+ ('Glasses', 31940),
+ ('Bicycle', 31409),
+ ('Furniture', 30656),
+ ('Sculpture', 29643),
+ ('Bottle', 27558),
+ ('Dog', 26980),
+ ('Snack', 26796),
+ ('Human hand', 26664),
+ ('Bird', 25791),
+ ('Book', 25415),
+ ('Guitar', 24386),
+ ('Jacket', 23998),
+ ('Poster', 22192),
+ ('Dessert', 21284),
+ ('Baked goods', 20657),
+ ('Drink', 19754),
+ ('Flag', 18588),
+ ('Houseplant', 18205),
+ ('Tableware', 17613),
+ ('Airplane', 17218),
+ ('Door', 17195),
+ ('Sports uniform', 17068),
+ ('Shelf', 16865),
+ ('Drum', 16612),
+ ('Vehicle', 16542),
+ ('Microphone', 15269),
+ ('Street light', 14957),
+ ('Cat', 14879),
+ ('Fruit', 13684),
+ ('Fast food', 13536),
+ ('Animal', 12932),
+ ('Vegetable', 12534),
+ ('Train', 12358),
+ ('Horse', 11948),
+ ('Flowerpot', 11728),
+ ('Motorcycle', 11621),
+ ('Fish', 11517),
+ ('Desk', 11405),
+ ('Helmet', 10996),
+ ('Truck', 10915),
+ ('Bus', 10695),
+ ('Hat', 10532),
+ ('Auto part', 10488),
+ ('Musical instrument', 10303),
+ ('Sunglasses', 10207),
+ ('Picture frame', 10096),
+ ('Sports equipment', 10015),
+ ('Shorts', 9999),
+ ('Wine glass', 9632),
+ ('Duck', 9242),
+ ('Wine', 9032),
+ ('Rose', 8781),
+ ('Tie', 8693),
+ ('Butterfly', 8436),
+ ('Beer', 7978),
+ ('Cabinetry', 7956),
+ ('Laptop', 7907),
+ ('Insect', 7497),
+ ('Goggles', 7363),
+ ('Shirt', 7098),
+ ('Dairy Product', 7021),
+ ('Marine invertebrates', 7014),
+ ('Cattle', 7006),
+ ('Trousers', 6903),
+ ('Van', 6843),
+ ('Billboard', 6777),
+ ('Balloon', 6367),
+ ('Human nose', 6103),
+ ('Tent', 6073),
+ ('Camera', 6014),
+ ('Doll', 6002),
+ ('Coat', 5951),
+ ('Mobile phone', 5758),
+ ('Swimwear', 5729),
+ ('Strawberry', 5691),
+ ('Stairs', 5643),
+ ('Goose', 5599),
+ ('Umbrella', 5536),
+ ('Cake', 5508),
+ ('Sun hat', 5475),
+ ('Bench', 5310),
+ ('Bookcase', 5163),
+ ('Bee', 5140),
+ ('Computer monitor', 5078),
+ ('Hiking equipment', 4983),
+ ('Office building', 4981),
+ ('Coffee cup', 4748),
+ ('Curtain', 4685),
+ ('Plate', 4651),
+ ('Box', 4621),
+ ('Tomato', 4595),
+ ('Coffee table', 4529),
+ ('Office supplies', 4473),
+ ('Maple', 4416),
+ ('Muffin', 4365),
+ ('Cocktail', 4234),
+ ('Castle', 4197),
+ ('Couch', 4134),
+ ('Pumpkin', 3983),
+ ('Computer keyboard', 3960),
+ ('Human mouth', 3926),
+ ('Christmas tree', 3893),
+ ('Mushroom', 3883),
+ ('Swimming pool', 3809),
+ ('Pastry', 3799),
+ ('Lavender (Plant)', 3769),
+ ('Football helmet', 3732),
+ ('Bread', 3648),
+ ('Traffic sign', 3628),
+ ('Common sunflower', 3597),
+ ('Television', 3550),
+ ('Bed', 3525),
+ ('Cookie', 3485),
+ ('Fountain', 3484),
+ ('Paddle', 3447),
+ ('Bicycle helmet', 3429),
+ ('Porch', 3420),
+ ('Deer', 3387),
+ ('Fedora', 3339),
+ ('Canoe', 3338),
+ ('Carnivore', 3266),
+ ('Bowl', 3202),
+ ('Human eye', 3166),
+ ('Ball', 3118),
+ ('Pillow', 3077),
+ ('Salad', 3061),
+ ('Beetle', 3060),
+ ('Orange', 3050),
+ ('Drawer', 2958),
+ ('Platter', 2937),
+ ('Elephant', 2921),
+ ('Seafood', 2921),
+ ('Monkey', 2915),
+ ('Countertop', 2879),
+ ('Watercraft', 2831),
+ ('Helicopter', 2805),
+ ('Kitchen appliance', 2797),
+ ('Personal flotation device', 2781),
+ ('Swan', 2739),
+ ('Lamp', 2711),
+ ('Boot', 2695),
+ ('Bronze sculpture', 2693),
+ ('Chicken', 2677),
+ ('Taxi', 2643),
+ ('Juice', 2615),
+ ('Cowboy hat', 2604),
+ ('Apple', 2600),
+ ('Tin can', 2590),
+ ('Necklace', 2564),
+ ('Ice cream', 2560),
+ ('Human beard', 2539),
+ ('Coin', 2536),
+ ('Candle', 2515),
+ ('Cart', 2512),
+ ('High heels', 2441),
+ ('Weapon', 2433),
+ ('Handbag', 2406),
+ ('Penguin', 2396),
+ ('Rifle', 2352),
+ ('Violin', 2336),
+ ('Skull', 2304),
+ ('Lantern', 2285),
+ ('Scarf', 2269),
+ ('Saucer', 2225),
+ ('Sheep', 2215),
+ ('Vase', 2189),
+ ('Lily', 2180),
+ ('Mug', 2154),
+ ('Parrot', 2140),
+ ('Human ear', 2137),
+ ('Sandal', 2115),
+ ('Lizard', 2100),
+ ('Kitchen & dining room table', 2063),
+ ('Spider', 1977),
+ ('Coffee', 1974),
+ ('Goat', 1926),
+ ('Squirrel', 1922),
+ ('Cello', 1913),
+ ('Sushi', 1881),
+ ('Tortoise', 1876),
+ ('Pizza', 1870),
+ ('Studio couch', 1864),
+ ('Barrel', 1862),
+ ('Cosmetics', 1841),
+ ('Moths and butterflies', 1841),
+ ('Convenience store', 1817),
+ ('Watch', 1792),
+ ('Home appliance', 1786),
+ ('Harbor seal', 1780),
+ ('Luggage and bags', 1756),
+ ('Vehicle registration plate', 1754),
+ ('Shrimp', 1751),
+ ('Jellyfish', 1730),
+ ('French fries', 1723),
+ ('Egg (Food)', 1698),
+ ('Football', 1697),
+ ('Musical keyboard', 1683),
+ ('Falcon', 1674),
+ ('Candy', 1660),
+ ('Medical equipment', 1654),
+ ('Eagle', 1651),
+ ('Dinosaur', 1634),
+ ('Surfboard', 1630),
+ ('Tank', 1628),
+ ('Grape', 1624),
+ ('Lion', 1624),
+ ('Owl', 1622),
+ ('Ski', 1613),
+ ('Waste container', 1606),
+ ('Frog', 1591),
+ ('Sparrow', 1585),
+ ('Rabbit', 1581),
+ ('Pen', 1546),
+ ('Sea lion', 1537),
+ ('Spoon', 1521),
+ ('Sink', 1512),
+ ('Teddy bear', 1507),
+ ('Bull', 1495),
+ ('Sofa bed', 1490),
+ ('Dragonfly', 1479),
+ ('Brassiere', 1478),
+ ('Chest of drawers', 1472),
+ ('Aircraft', 1466),
+ ('Human foot', 1463),
+ ('Pig', 1455),
+ ('Fork', 1454),
+ ('Antelope', 1438),
+ ('Tripod', 1427),
+ ('Tool', 1424),
+ ('Cheese', 1422),
+ ('Lemon', 1397),
+ ('Hamburger', 1393),
+ ('Dolphin', 1390),
+ ('Mirror', 1390),
+ ('Marine mammal', 1387),
+ ('Giraffe', 1385),
+ ('Snake', 1368),
+ ('Gondola', 1364),
+ ('Wheelchair', 1360),
+ ('Piano', 1358),
+ ('Cupboard', 1348),
+ ('Banana', 1345),
+ ('Trumpet', 1335),
+ ('Lighthouse', 1333),
+ ('Invertebrate', 1317),
+ ('Carrot', 1268),
+ ('Sock', 1260),
+ ('Tiger', 1241),
+ ('Camel', 1224),
+ ('Parachute', 1224),
+ ('Bathroom accessory', 1223),
+ ('Earrings', 1221),
+ ('Headphones', 1218),
+ ('Skirt', 1198),
+ ('Skateboard', 1190),
+ ('Sandwich', 1148),
+ ('Saxophone', 1141),
+ ('Goldfish', 1136),
+ ('Stool', 1104),
+ ('Traffic light', 1097),
+ ('Shellfish', 1081),
+ ('Backpack', 1079),
+ ('Sea turtle', 1078),
+ ('Cucumber', 1075),
+ ('Tea', 1051),
+ ('Toilet', 1047),
+ ('Roller skates', 1040),
+ ('Mule', 1039),
+ ('Bust', 1031),
+ ('Broccoli', 1030),
+ ('Crab', 1020),
+ ('Oyster', 1019),
+ ('Cannon', 1012),
+ ('Zebra', 1012),
+ ('French horn', 1008),
+ ('Grapefruit', 998),
+ ('Whiteboard', 997),
+ ('Zucchini', 997),
+ ('Crocodile', 992),
+
+ ('Clock', 960),
+ ('Wall clock', 958),
+
+ ('Doughnut', 869),
+ ('Snail', 868),
+
+ ('Baseball glove', 859),
+
+ ('Panda', 830),
+ ('Tennis racket', 830),
+
+ ('Pear', 652),
+
+ ('Bagel', 617),
+ ('Oven', 616),
+ ('Ladybug', 615),
+ ('Shark', 615),
+ ('Polar bear', 614),
+ ('Ostrich', 609),
+
+ ('Hot dog', 473),
+ ('Microwave oven', 467),
+ ('Fire hydrant', 20),
+ ('Stop sign', 20),
+ ('Parking meter', 20),
+ ('Bear', 20),
+ ('Flying disc', 20),
+ ('Snowboard', 20),
+ ('Tennis ball', 20),
+ ('Kite', 20),
+ ('Baseball bat', 20),
+ ('Kitchen knife', 20),
+ ('Knife', 20),
+ ('Submarine sandwich', 20),
+ ('Computer mouse', 20),
+ ('Remote control', 20),
+ ('Toaster', 20),
+ ('Sink', 20),
+ ('Refrigerator', 20),
+ ('Alarm clock', 20),
+ ('Wall clock', 20),
+ ('Scissors', 20),
+ ('Hair dryer', 20),
+ ('Toothbrush', 20),
+ ('Suitcase', 20)
+]
diff --git a/StableSR/taming/data/sflckr.py b/StableSR/taming/data/sflckr.py
new file mode 100644
index 0000000000000000000000000000000000000000..91101be5953b113f1e58376af637e43f366b3dee
--- /dev/null
+++ b/StableSR/taming/data/sflckr.py
@@ -0,0 +1,91 @@
+import os
+import numpy as np
+import cv2
+import albumentations
+from PIL import Image
+from torch.utils.data import Dataset
+
+
+class SegmentationBase(Dataset):
+ def __init__(self,
+ data_csv, data_root, segmentation_root,
+ size=None, random_crop=False, interpolation="bicubic",
+ n_labels=182, shift_segmentation=False,
+ ):
+ self.n_labels = n_labels
+ self.shift_segmentation = shift_segmentation
+ self.data_csv = data_csv
+ self.data_root = data_root
+ self.segmentation_root = segmentation_root
+ with open(self.data_csv, "r") as f:
+ self.image_paths = f.read().splitlines()
+ self._length = len(self.image_paths)
+ self.labels = {
+ "relative_file_path_": [l for l in self.image_paths],
+ "file_path_": [os.path.join(self.data_root, l)
+ for l in self.image_paths],
+ "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png"))
+ for l in self.image_paths]
+ }
+
+ size = None if size is not None and size<=0 else size
+ self.size = size
+ if self.size is not None:
+ self.interpolation = interpolation
+ self.interpolation = {
+ "nearest": cv2.INTER_NEAREST,
+ "bilinear": cv2.INTER_LINEAR,
+ "bicubic": cv2.INTER_CUBIC,
+ "area": cv2.INTER_AREA,
+ "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
+ self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
+ interpolation=self.interpolation)
+ self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
+ interpolation=cv2.INTER_NEAREST)
+ self.center_crop = not random_crop
+ if self.center_crop:
+ self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
+ else:
+ self.cropper = albumentations.RandomCrop(height=self.size, width=self.size)
+ self.preprocessor = self.cropper
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = dict((k, self.labels[k][i]) for k in self.labels)
+ image = Image.open(example["file_path_"])
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+ if self.size is not None:
+ image = self.image_rescaler(image=image)["image"]
+ segmentation = Image.open(example["segmentation_path_"])
+ assert segmentation.mode == "L", segmentation.mode
+ segmentation = np.array(segmentation).astype(np.uint8)
+ if self.shift_segmentation:
+ # used to support segmentations containing unlabeled==255 label
+ segmentation = segmentation+1
+ if self.size is not None:
+ segmentation = self.segmentation_rescaler(image=segmentation)["image"]
+ if self.size is not None:
+ processed = self.preprocessor(image=image,
+ mask=segmentation
+ )
+ else:
+ processed = {"image": image,
+ "mask": segmentation
+ }
+ example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
+ segmentation = processed["mask"]
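+        # one-hot encode the label map: (H, W) integer labels -> (H, W, n_labels) float array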
+ onehot = np.eye(self.n_labels)[segmentation]
+ example["segmentation"] = onehot
+ return example
+
+
+class Examples(SegmentationBase):
+ def __init__(self, size=None, random_crop=False, interpolation="bicubic"):
+ super().__init__(data_csv="data/sflckr_examples.txt",
+ data_root="data/sflckr_images",
+ segmentation_root="data/sflckr_segmentations",
+ size=size, random_crop=random_crop, interpolation=interpolation)
diff --git a/StableSR/taming/data/utils.py b/StableSR/taming/data/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b3c3d53cd2b6c72b481b59834cf809d3735b394
--- /dev/null
+++ b/StableSR/taming/data/utils.py
@@ -0,0 +1,169 @@
+import collections
+import os
+import tarfile
+import urllib
+import zipfile
+from pathlib import Path
+
+import numpy as np
+import torch
+from taming.data.helper_types import Annotation
+try:
+    from torch._six import string_classes
+except ImportError:  # torch._six was removed in newer PyTorch versions
+    string_classes = (str, bytes)
+from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
+from tqdm import tqdm
+
+
+def unpack(path):
+ if path.endswith("tar.gz"):
+ with tarfile.open(path, "r:gz") as tar:
+ tar.extractall(path=os.path.split(path)[0])
+ elif path.endswith("tar"):
+ with tarfile.open(path, "r:") as tar:
+ tar.extractall(path=os.path.split(path)[0])
+ elif path.endswith("zip"):
+ with zipfile.ZipFile(path, "r") as f:
+ f.extractall(path=os.path.split(path)[0])
+ else:
+ raise NotImplementedError(
+ "Unknown file extension: {}".format(os.path.splitext(path)[1])
+ )
+
+
+def reporthook(bar):
+ """tqdm progress bar for downloads."""
+
+ def hook(b=1, bsize=1, tsize=None):
+ if tsize is not None:
+ bar.total = tsize
+ bar.update(b * bsize - bar.n)
+
+ return hook
+
+
+def get_root(name):
+ base = "data/"
+ root = os.path.join(base, name)
+ os.makedirs(root, exist_ok=True)
+ return root
+
+
+def is_prepared(root):
+ return Path(root).joinpath(".ready").exists()
+
+
+def mark_prepared(root):
+ Path(root).joinpath(".ready").touch()
+
+
+def prompt_download(file_, source, target_dir, content_dir=None):
+ targetpath = os.path.join(target_dir, file_)
+ while not os.path.exists(targetpath):
+ if content_dir is not None and os.path.exists(
+ os.path.join(target_dir, content_dir)
+ ):
+ break
+ print(
+ "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath)
+ )
+ if content_dir is not None:
+ print(
+ "Or place its content into '{}'.".format(
+ os.path.join(target_dir, content_dir)
+ )
+ )
+ input("Press Enter when done...")
+ return targetpath
+
+
+def download_url(file_, url, target_dir):
+ targetpath = os.path.join(target_dir, file_)
+ os.makedirs(target_dir, exist_ok=True)
+ with tqdm(
+ unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_
+ ) as bar:
+ urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar))
+ return targetpath
+
+
+def download_urls(urls, target_dir):
+ paths = dict()
+ for fname, url in urls.items():
+ outpath = download_url(fname, url, target_dir)
+ paths[fname] = outpath
+ return paths
+
+
+def quadratic_crop(x, bbox, alpha=1.0):
+ """bbox is xmin, ymin, xmax, ymax"""
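+    # crop a square of side alpha * max(bbox width, bbox height) centered on the bbox,
+    # reflect-padding the image first if the square would extend beyond its borders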
+ im_h, im_w = x.shape[:2]
+ bbox = np.array(bbox, dtype=np.float32)
+ bbox = np.clip(bbox, 0, max(im_h, im_w))
+ center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3])
+ w = bbox[2] - bbox[0]
+ h = bbox[3] - bbox[1]
+ l = int(alpha * max(w, h))
+ l = max(l, 2)
+
+ required_padding = -1 * min(
+ center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l)
+ )
+ required_padding = int(np.ceil(required_padding))
+ if required_padding > 0:
+ padding = [
+ [required_padding, required_padding],
+ [required_padding, required_padding],
+ ]
+ padding += [[0, 0]] * (len(x.shape) - 2)
+ x = np.pad(x, padding, "reflect")
+ center = center[0] + required_padding, center[1] + required_padding
+ xmin = int(center[0] - l / 2)
+ ymin = int(center[1] - l / 2)
+ return np.array(x[ymin : ymin + l, xmin : xmin + l, ...])
+
+
+def custom_collate(batch):
+ r"""source: pytorch 1.9.0, only one modification to original code """
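+    # The single modification: a batch whose elements are sequences of Annotation objects is
+    # returned as-is instead of being transposed and collated recursively.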
+
+ elem = batch[0]
+ elem_type = type(elem)
+ if isinstance(elem, torch.Tensor):
+ out = None
+ if torch.utils.data.get_worker_info() is not None:
+ # If we're in a background process, concatenate directly into a
+ # shared memory tensor to avoid an extra copy
+ numel = sum([x.numel() for x in batch])
+ storage = elem.storage()._new_shared(numel)
+ out = elem.new(storage)
+ return torch.stack(batch, 0, out=out)
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
+ and elem_type.__name__ != 'string_':
+ if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
+ # array of string classes and object
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
+ raise TypeError(default_collate_err_msg_format.format(elem.dtype))
+
+ return custom_collate([torch.as_tensor(b) for b in batch])
+ elif elem.shape == (): # scalars
+ return torch.as_tensor(batch)
+ elif isinstance(elem, float):
+ return torch.tensor(batch, dtype=torch.float64)
+ elif isinstance(elem, int):
+ return torch.tensor(batch)
+ elif isinstance(elem, string_classes):
+ return batch
+ elif isinstance(elem, collections.abc.Mapping):
+ return {key: custom_collate([d[key] for d in batch]) for key in elem}
+ elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
+ return elem_type(*(custom_collate(samples) for samples in zip(*batch)))
+ if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added
+ return batch # added
+ elif isinstance(elem, collections.abc.Sequence):
+ # check to make sure that the elements in batch have consistent size
+ it = iter(batch)
+ elem_size = len(next(it))
+ if not all(len(elem) == elem_size for elem in it):
+ raise RuntimeError('each element in list of batch should be of equal size')
+ transposed = zip(*batch)
+ return [custom_collate(samples) for samples in transposed]
+
+ raise TypeError(default_collate_err_msg_format.format(elem_type))
diff --git a/StableSR/taming/lr_scheduler.py b/StableSR/taming/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..e598ed120159c53da6820a55ad86b89f5c70c82d
--- /dev/null
+++ b/StableSR/taming/lr_scheduler.py
@@ -0,0 +1,34 @@
+import numpy as np
+
+
+class LambdaWarmUpCosineScheduler:
+ """
+ note: use with a base_lr of 1.0
+ """
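+    # Linear warm-up from lr_start to lr_max over warm_up_steps, then cosine decay to lr_min
+    # by max_decay_steps; the schedule returns a multiplier, hence the base learning rate of 1.0
+    # (e.g. when wrapped in torch.optim.lr_scheduler.LambdaLR).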
+ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
+ self.lr_warm_up_steps = warm_up_steps
+ self.lr_start = lr_start
+ self.lr_min = lr_min
+ self.lr_max = lr_max
+ self.lr_max_decay_steps = max_decay_steps
+ self.last_lr = 0.
+ self.verbosity_interval = verbosity_interval
+
+ def schedule(self, n):
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+ if n < self.lr_warm_up_steps:
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
+ self.last_lr = lr
+ return lr
+ else:
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
+ t = min(t, 1.0)
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
+ 1 + np.cos(t * np.pi))
+ self.last_lr = lr
+ return lr
+
+ def __call__(self, n):
+ return self.schedule(n)
+
diff --git a/StableSR/taming/models/cond_transformer.py b/StableSR/taming/models/cond_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4c63730fa86ac1b92b37af14c14fb696595b1ab
--- /dev/null
+++ b/StableSR/taming/models/cond_transformer.py
@@ -0,0 +1,352 @@
+import os, math
+import torch
+import torch.nn.functional as F
+import pytorch_lightning as pl
+
+from main import instantiate_from_config
+from taming.modules.util import SOSProvider
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class Net2NetTransformer(pl.LightningModule):
+ def __init__(self,
+ transformer_config,
+ first_stage_config,
+ cond_stage_config,
+ permuter_config=None,
+ ckpt_path=None,
+ ignore_keys=[],
+ first_stage_key="image",
+ cond_stage_key="depth",
+ downsample_cond_size=-1,
+ pkeep=1.0,
+ sos_token=0,
+ unconditional=False,
+ ):
+ super().__init__()
+ self.be_unconditional = unconditional
+ self.sos_token = sos_token
+ self.first_stage_key = first_stage_key
+ self.cond_stage_key = cond_stage_key
+ self.init_first_stage_from_ckpt(first_stage_config)
+ self.init_cond_stage_from_ckpt(cond_stage_config)
+ if permuter_config is None:
+ permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
+ self.permuter = instantiate_from_config(config=permuter_config)
+ self.transformer = instantiate_from_config(config=transformer_config)
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+ self.downsample_cond_size = downsample_cond_size
+ self.pkeep = pkeep
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ for k in sd.keys():
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ self.print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path}")
+
+ def init_first_stage_from_ckpt(self, config):
+ model = instantiate_from_config(config)
+ model = model.eval()
+ model.train = disabled_train
+ self.first_stage_model = model
+
+ def init_cond_stage_from_ckpt(self, config):
+ if config == "__is_first_stage__":
+ print("Using first stage also as cond stage.")
+ self.cond_stage_model = self.first_stage_model
+ elif config == "__is_unconditional__" or self.be_unconditional:
+ print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
+ f"Prepending {self.sos_token} as a sos token.")
+ self.be_unconditional = True
+ self.cond_stage_key = self.first_stage_key
+ self.cond_stage_model = SOSProvider(self.sos_token)
+ else:
+ model = instantiate_from_config(config)
+ model = model.eval()
+ model.train = disabled_train
+ self.cond_stage_model = model
+
+ def forward(self, x, c):
+ # one step to produce the logits
+ _, z_indices = self.encode_to_z(x)
+ _, c_indices = self.encode_to_c(c)
+
+ if self.training and self.pkeep < 1.0:
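+            # token dropout: keep each ground-truth index with probability pkeep,
+            # otherwise replace it with a random index from the vocabulary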
+ mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
+ device=z_indices.device))
+ mask = mask.round().to(dtype=torch.int64)
+ r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
+ a_indices = mask*z_indices+(1-mask)*r_indices
+ else:
+ a_indices = z_indices
+
+ cz_indices = torch.cat((c_indices, a_indices), dim=1)
+
+ # target includes all sequence elements (no need to handle first one
+ # differently because we are conditioning)
+ target = z_indices
+ # make the prediction
+ logits, _ = self.transformer(cz_indices[:, :-1])
+        # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
+        logits = logits[:, c_indices.shape[1]-1:]
+
+        return logits, target
+
+    def top_k_logits(self, logits, k):
+        v, ix = torch.topk(logits, k)
+        out = logits.clone()
+        out[out < v[..., [-1]]] = -float('Inf')
+        return out
+
+    @torch.no_grad()
+    def sample(self, x, c, steps, temperature=1.0, sample=False, top_k=None,
+               callback=lambda k: None):
+        x = torch.cat((c,x),dim=1)
+        block_size = self.transformer.get_block_size()
+        assert not self.transformer.training
+        if self.pkeep <= 0.0:
+            # one pass suffices since input is pure noise anyway
+            assert len(x.shape)==2
+            noise_shape = (x.shape[0], steps-1)
+            #noise = torch.randint(self.transformer.config.vocab_size, noise_shape).to(x)
+            noise = c.clone()[:,x.shape[1]-c.shape[1]:-1]
+            x = torch.cat((x,noise),dim=1)
+            logits, _ = self.transformer(x)
+            # take all logits for now and scale by temp
+            logits = logits / temperature
+            # optionally crop probabilities to only the top k options
+            if top_k is not None:
+                logits = self.top_k_logits(logits, top_k)
+            # apply softmax to convert to probabilities
+            probs = F.softmax(logits, dim=-1)
+            # sample from the distribution or take the most likely
+            if sample:
+                shape = probs.shape
+                probs = probs.reshape(shape[0]*shape[1],shape[2])
+                ix = torch.multinomial(probs, num_samples=1)
+                probs = probs.reshape(shape[0],shape[1],shape[2])
+                ix = ix.reshape(shape[0],shape[1])
+            else:
+                _, ix = torch.topk(probs, k=1, dim=-1)
+            # cut off conditioning
+            x = ix[:, c.shape[1]-1:]
+        else:
+            for k in range(steps):
+                callback(k)
+                assert x.size(1) <= block_size # make sure model can see conditioning
+                x_cond = x if x.size(1) <= block_size else x[:, -block_size:]  # crop context if needed
+                logits, _ = self.transformer(x_cond)
+                # pluck the logits at the final step and scale by temperature
+                logits = logits[:, -1, :] / temperature
+                # optionally crop probabilities to only the top k options
+                if top_k is not None:
+                    logits = self.top_k_logits(logits, top_k)
+                # apply softmax to convert to probabilities
+                probs = F.softmax(logits, dim=-1)
+                # sample from the distribution or take the most likely
+                if sample:
+                    ix = torch.multinomial(probs, num_samples=1)
+                else:
+                    _, ix = torch.topk(probs, k=1, dim=-1)
+                # append to the sequence and continue
+                x = torch.cat((x, ix), dim=1)
+            # cut off conditioning
+            x = x[:, c.shape[1]:]
+        return x
+
+    @torch.no_grad()
+    def encode_to_z(self, x):
+        quant_z, _, info = self.first_stage_model.encode(x)
+        indices = info[2].view(quant_z.shape[0], -1)
+        indices = self.permuter(indices)
+        return quant_z, indices
+
+    @torch.no_grad()
+    def encode_to_c(self, c):
+        if self.downsample_cond_size > -1:
+            c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
+ quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c)
+ if len(indices.shape) > 2:
+ indices = indices.view(c.shape[0], -1)
+ return quant_c, indices
+
+ @torch.no_grad()
+ def decode_to_img(self, index, zshape):
+ index = self.permuter(index, reverse=True)
+ bhwc = (zshape[0],zshape[2],zshape[3],zshape[1])
+ quant_z = self.first_stage_model.quantize.get_codebook_entry(
+ index.reshape(-1), shape=bhwc)
+ x = self.first_stage_model.decode(quant_z)
+ return x
+
+ @torch.no_grad()
+ def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
+ log = dict()
+
+ N = 4
+ if lr_interface:
+ x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8)
+ else:
+ x, c = self.get_xc(batch, N)
+ x = x.to(device=self.device)
+ c = c.to(device=self.device)
+
+ quant_z, z_indices = self.encode_to_z(x)
+ quant_c, c_indices = self.encode_to_c(c)
+
+        # create a "half" sample
+ z_start_indices = z_indices[:,:z_indices.shape[1]//2]
+ index_sample = self.sample(z_start_indices, c_indices,
+ steps=z_indices.shape[1]-z_start_indices.shape[1],
+ temperature=temperature if temperature is not None else 1.0,
+ sample=True,
+ top_k=top_k if top_k is not None else 100,
+ callback=callback if callback is not None else lambda k: None)
+ x_sample = self.decode_to_img(index_sample, quant_z.shape)
+
+ # sample
+ z_start_indices = z_indices[:, :0]
+ index_sample = self.sample(z_start_indices, c_indices,
+ steps=z_indices.shape[1],
+ temperature=temperature if temperature is not None else 1.0,
+ sample=True,
+ top_k=top_k if top_k is not None else 100,
+ callback=callback if callback is not None else lambda k: None)
+ x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)
+
+ # det sample
+ z_start_indices = z_indices[:, :0]
+ index_sample = self.sample(z_start_indices, c_indices,
+ steps=z_indices.shape[1],
+ sample=False,
+ callback=callback if callback is not None else lambda k: None)
+ x_sample_det = self.decode_to_img(index_sample, quant_z.shape)
+
+ # reconstruction
+ x_rec = self.decode_to_img(z_indices, quant_z.shape)
+
+ log["inputs"] = x
+ log["reconstructions"] = x_rec
+
+ if self.cond_stage_key in ["objects_bbox", "objects_center_points"]:
+ figure_size = (x_rec.shape[2], x_rec.shape[3])
+ dataset = kwargs["pl_module"].trainer.datamodule.datasets["validation"]
+ label_for_category_no = dataset.get_textual_label_for_category_no
+ plotter = dataset.conditional_builders[self.cond_stage_key].plot
+ log["conditioning"] = torch.zeros_like(log["reconstructions"])
+ for i in range(quant_c.shape[0]):
+ log["conditioning"][i] = plotter(quant_c[i], label_for_category_no, figure_size)
+ log["conditioning_rec"] = log["conditioning"]
+ elif self.cond_stage_key != "image":
+ cond_rec = self.cond_stage_model.decode(quant_c)
+ if self.cond_stage_key == "segmentation":
+ # get image from segmentation mask
+ num_classes = cond_rec.shape[1]
+
+ c = torch.argmax(c, dim=1, keepdim=True)
+ c = F.one_hot(c, num_classes=num_classes)
+ c = c.squeeze(1).permute(0, 3, 1, 2).float()
+ c = self.cond_stage_model.to_rgb(c)
+
+ cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
+ cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
+ cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
+ cond_rec = self.cond_stage_model.to_rgb(cond_rec)
+ log["conditioning_rec"] = cond_rec
+ log["conditioning"] = c
+
+ log["samples_half"] = x_sample
+ log["samples_nopix"] = x_sample_nopix
+ log["samples_det"] = x_sample_det
+ return log
+
+ def get_input(self, key, batch):
+ x = batch[key]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ if len(x.shape) == 4:
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
+ if x.dtype == torch.double:
+ x = x.float()
+ return x
+
+ def get_xc(self, batch, N=None):
+ x = self.get_input(self.first_stage_key, batch)
+ c = self.get_input(self.cond_stage_key, batch)
+ if N is not None:
+ x = x[:N]
+ c = c[:N]
+ return x, c
+
+ def shared_step(self, batch, batch_idx):
+ x, c = self.get_xc(batch)
+ logits, target = self(x, c)
+ loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
+ return loss
+
+ def training_step(self, batch, batch_idx):
+ loss = self.shared_step(batch, batch_idx)
+ self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ return loss
+
+ def validation_step(self, batch, batch_idx):
+ loss = self.shared_step(batch, batch_idx)
+ self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ return loss
+
+ def configure_optimizers(self):
+ """
+ Following minGPT:
+ This long function is unfortunately doing something very simple and is being very defensive:
+ We are separating out all parameters of the model into two buckets: those that will experience
+ weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
+ We are then returning the PyTorch optimizer object.
+ """
+ # separate out all parameters to those that will and won't experience regularizing weight decay
+ decay = set()
+ no_decay = set()
+ whitelist_weight_modules = (torch.nn.Linear, )
+ blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+ for mn, m in self.transformer.named_modules():
+ for pn, p in m.named_parameters():
+ fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
+
+ if pn.endswith('bias'):
+ # all biases will not be decayed
+ no_decay.add(fpn)
+ elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
+ # weights of whitelist modules will be weight decayed
+ decay.add(fpn)
+ elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
+ # weights of blacklist modules will NOT be weight decayed
+ no_decay.add(fpn)
+
+ # special case the position embedding parameter in the root GPT module as not decayed
+ no_decay.add('pos_emb')
+
+ # validate that we considered every parameter
+ param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
+ inter_params = decay & no_decay
+ union_params = decay | no_decay
+ assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
+ assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
+ % (str(param_dict.keys() - union_params), )
+
+ # create the pytorch optimizer object
+ optim_groups = [
+ {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
+ {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
+ ]
+ optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
+ return optimizer
diff --git a/StableSR/taming/models/dummy_cond_stage.py b/StableSR/taming/models/dummy_cond_stage.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e19938078752e09b926a3e749907ee99a258ca0
--- /dev/null
+++ b/StableSR/taming/models/dummy_cond_stage.py
@@ -0,0 +1,22 @@
+from torch import Tensor
+
+
+class DummyCondStage:
+ def __init__(self, conditional_key):
+ self.conditional_key = conditional_key
+ self.train = None
+
+ def eval(self):
+ return self
+
+ @staticmethod
+ def encode(c: Tensor):
+ return c, None, (None, None, c)
+
+ @staticmethod
+ def decode(c: Tensor):
+ return c
+
+ @staticmethod
+ def to_rgb(c: Tensor):
+ return c
diff --git a/StableSR/taming/models/vqgan.py b/StableSR/taming/models/vqgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6950baa5f739111cd64c17235dca8be3a5f8037
--- /dev/null
+++ b/StableSR/taming/models/vqgan.py
@@ -0,0 +1,404 @@
+import torch
+import torch.nn.functional as F
+import pytorch_lightning as pl
+
+from main import instantiate_from_config
+
+from taming.modules.diffusionmodules.model import Encoder, Decoder
+from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
+from taming.modules.vqvae.quantize import GumbelQuantize
+from taming.modules.vqvae.quantize import EMAVectorQuantizer
+
+class VQModel(pl.LightningModule):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ remap=None,
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
+ ):
+ super().__init__()
+ self.image_key = image_key
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ self.loss = instantiate_from_config(lossconfig)
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
+ remap=remap, sane_index_shape=sane_index_shape)
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+ self.image_key = image_key
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels)==int
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path}")
+
+ def encode(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ quant, emb_loss, info = self.quantize(h)
+ return quant, emb_loss, info
+
+ def decode(self, quant):
+ quant = self.post_quant_conv(quant)
+ dec = self.decoder(quant)
+ return dec
+
+ def decode_code(self, code_b):
+ quant_b = self.quantize.embed_code(code_b)
+ dec = self.decode(quant_b)
+ return dec
+
+ def forward(self, input):
+ quant, diff, _ = self.encode(input)
+ dec = self.decode(quant)
+ return dec, diff
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
+ return x.float()
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+
+ if optimizer_idx == 0:
+ # autoencode
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+
+ self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return aeloss
+
+ if optimizer_idx == 1:
+ # discriminator
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+ self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return discloss
+
+ def validation_step(self, batch, batch_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+ rec_loss = log_dict_ae["val/rec_loss"]
+ self.log("val/rec_loss", rec_loss,
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
+ self.log("val/aeloss", aeloss,
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quantize.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=lr, betas=(0.5, 0.9))
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ lr=lr, betas=(0.5, 0.9))
+ return [opt_ae, opt_disc], []
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+
+ def log_images(self, batch, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ xrec, _ = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log["inputs"] = x
+ log["reconstructions"] = xrec
+ return log
+
+ def to_rgb(self, x):
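+        # project the n-channel segmentation map to 3 channels with a fixed random 1x1 convolution,
+        # then rescale the result to [-1, 1] for visualization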
+ assert self.image_key == "segmentation"
+ if not hasattr(self, "colorize"):
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+ x = F.conv2d(x, weight=self.colorize)
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+ return x
+
+
+class VQSegmentationModel(VQModel):
+ def __init__(self, n_labels, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1))
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quantize.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=lr, betas=(0.5, 0.9))
+ return opt_ae
+
+ def training_step(self, batch, batch_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return aeloss
+
+ def validation_step(self, batch, batch_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val")
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ total_loss = log_dict_ae["val/total_loss"]
+ self.log("val/total_loss", total_loss,
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
+ return aeloss
+
+ @torch.no_grad()
+ def log_images(self, batch, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ xrec, _ = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ # convert logits to indices
+ xrec = torch.argmax(xrec, dim=1, keepdim=True)
+ xrec = F.one_hot(xrec, num_classes=x.shape[1])
+ xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log["inputs"] = x
+ log["reconstructions"] = xrec
+ return log
+
+
+class VQNoDiscModel(VQModel):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None
+ ):
+ super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim,
+ ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key,
+ colorize_nlabels=colorize_nlabels)
+
+ def training_step(self, batch, batch_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+ # autoencode
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")
+ output = pl.TrainResult(minimize=aeloss)
+ output.log("train/aeloss", aeloss,
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return output
+
+ def validation_step(self, batch, batch_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val")
+ rec_loss = log_dict_ae["val/rec_loss"]
+ output = pl.EvalResult(checkpoint_on=rec_loss)
+ output.log("val/rec_loss", rec_loss,
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ output.log("val/aeloss", aeloss,
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ output.log_dict(log_dict_ae)
+
+ return output
+
+ def configure_optimizers(self):
+ optimizer = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quantize.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=self.learning_rate, betas=(0.5, 0.9))
+ return optimizer
+
+
+class GumbelVQ(VQModel):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ temperature_scheduler_config,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ kl_weight=1e-8,
+ remap=None,
+ ):
+
+ z_channels = ddconfig["z_channels"]
+ super().__init__(ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=ignore_keys,
+ image_key=image_key,
+ colorize_nlabels=colorize_nlabels,
+ monitor=monitor,
+ )
+
+ self.loss.n_classes = n_embed
+ self.vocab_size = n_embed
+
+ self.quantize = GumbelQuantize(z_channels, embed_dim,
+ n_embed=n_embed,
+ kl_weight=kl_weight, temp_init=1.0,
+ remap=remap)
+
+ self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+ def temperature_scheduling(self):
+ self.quantize.temperature = self.temperature_scheduler(self.global_step)
+
+ def encode_to_prequant(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ return h
+
+ def decode_code(self, code_b):
+ raise NotImplementedError
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ self.temperature_scheduling()
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+
+ if optimizer_idx == 0:
+ # autoencode
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return aeloss
+
+ if optimizer_idx == 1:
+ # discriminator
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return discloss
+
+ def validation_step(self, batch, batch_idx):
+ x = self.get_input(batch, self.image_key)
+        xrec, qloss = self(x)
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+ rec_loss = log_dict_ae["val/rec_loss"]
+ self.log("val/rec_loss", rec_loss,
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+ self.log("val/aeloss", aeloss,
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def log_images(self, batch, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ # encode
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ quant, _, _ = self.quantize(h)
+ # decode
+ x_rec = self.decode(quant)
+ log["inputs"] = x
+ log["reconstructions"] = x_rec
+ return log
+
+
+class EMAVQ(VQModel):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ remap=None,
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
+ ):
+ super().__init__(ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=ignore_keys,
+ image_key=image_key,
+ colorize_nlabels=colorize_nlabels,
+ monitor=monitor,
+ )
+ self.quantize = EMAVectorQuantizer(n_embed=n_embed,
+ embedding_dim=embed_dim,
+ beta=0.25,
+ remap=remap)
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ #Remove self.quantize from parameter list since it is updated via EMA
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=lr, betas=(0.5, 0.9))
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ lr=lr, betas=(0.5, 0.9))
+ return [opt_ae, opt_disc], []
\ No newline at end of file
diff --git a/StableSR/taming/modules/diffusionmodules/model.py b/StableSR/taming/modules/diffusionmodules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3a5db6aa2ef915e270f1ae135e4a9918fdd884c
--- /dev/null
+++ b/StableSR/taming/modules/diffusionmodules/model.py
@@ -0,0 +1,776 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
+ From Fairseq.
+ Build sinusoidal embeddings.
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
+ return emb
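+
+# Illustrative example (not from the upstream file): for t = torch.arange(4) and
+# embedding_dim=8, get_timestep_embedding(t, 8) returns a (4, 8) tensor whose first
+# four columns are sin(t * f_i) and last four are cos(t * f_i), with frequencies f_i
+# spaced geometrically from 1 down to 1/10000.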
+
+
+def nonlinearity(x):
+ # swish
+ return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0,1,0,1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
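+
+ # Note: the (0, 1, 0, 1) padding pads only the right and bottom edges, so the
+ # stride-2 3x3 convolution halves the spatial resolution exactly, matching the
+ # behaviour of the avg_pool2d branch.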
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+ dropout, temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels,
+ out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x+h
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = q.reshape(b,c,h*w)
+ q = q.permute(0,2,1) # b,hw,c
+ k = k.reshape(b,c,h*w) # b,c,hw
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b,c,h*w)
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b,c,h,w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
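+
+ # In matrix form, AttnBlock computes x + proj_out(v @ softmax(q^T k / sqrt(c))^T)
+ # over the flattened spatial positions, i.e. single-head self-attention with a
+ # residual connection (a non-local block).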
+
+
+class Model(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, use_timestep=True):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = self.ch*4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch,
+ self.temb_ch),
+ torch.nn.Linear(self.temb_ch,
+ self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ skip_in = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch*in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+
+ def forward(self, x, t=None):
+ #assert x.shape[2] == x.shape[3] == self.resolution
+
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Encoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, double_z=True, **ignore_kwargs):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ 2*z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+
+ def forward(self, x):
+ #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
+
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, give_pre_end=False, **ignorekwargs):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,)+tuple(ch_mult)
+ block_in = ch*ch_mult[self.num_resolutions-1]
+ curr_res = resolution // 2**(self.num_resolutions-1)
+ self.z_shape = (1,z_channels,curr_res,curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, z):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class VUNet(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
+ in_channels, c_channels,
+ resolution, z_channels, use_timestep=False, **ignore_kwargs):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = self.ch*4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch,
+ self.temb_ch),
+ torch.nn.Linear(self.temb_ch,
+ self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(c_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ self.z_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ skip_in = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch*in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+
+ def forward(self, x, z):
+ #assert x.shape[2] == x.shape[3] == self.resolution
+
+ if self.use_timestep:
+ # timestep embedding (note: this branch expects a timestep tensor `t`, which is
+ # not an argument of this forward; use_timestep defaults to False for VUNet)
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ z = self.z_in(z)
+ h = torch.cat((h,z),dim=1)
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class SimpleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ super().__init__()
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
+ ResnetBlock(in_channels=in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=2 * in_channels,
+ out_channels=4 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=4 * in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ nn.Conv2d(2*in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True)])
+ # end
+ self.norm_out = Normalize(in_channels)
+ self.conv_out = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ for i, layer in enumerate(self.model):
+ if i in [1,2,3]:
+ x = layer(x, None)
+ else:
+ x = layer(x)
+
+ h = self.norm_out(x)
+ h = nonlinearity(h)
+ x = self.conv_out(h)
+ return x
+
+
+class UpsampleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+ ch_mult=(2,2), dropout=0.0):
+ super().__init__()
+ # upsampling
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ block_in = in_channels
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.res_blocks = nn.ModuleList()
+ self.upsample_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ res_block = []
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ res_block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ self.res_blocks.append(nn.ModuleList(res_block))
+ if i_level != self.num_resolutions - 1:
+ self.upsample_blocks.append(Upsample(block_in, True))
+ curr_res = curr_res * 2
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # upsampling
+ h = x
+ for k, i_level in enumerate(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.res_blocks[i_level][i_block](h, None)
+ if i_level != self.num_resolutions - 1:
+ h = self.upsample_blocks[k](h)
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
diff --git a/StableSR/taming/modules/discriminator/model.py b/StableSR/taming/modules/discriminator/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..2aaa3110d0a7bcd05de7eca1e45101589ca5af05
--- /dev/null
+++ b/StableSR/taming/modules/discriminator/model.py
@@ -0,0 +1,67 @@
+import functools
+import torch.nn as nn
+
+
+from taming.modules.util import ActNorm
+
+
+def weights_init(m):
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
+ elif classname.find('BatchNorm') != -1:
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
+ nn.init.constant_(m.bias.data, 0)
+
+
+class NLayerDiscriminator(nn.Module):
+ """Defines a PatchGAN discriminator as in Pix2Pix
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
+ """
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
+ """Construct a PatchGAN discriminator
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ ndf (int) -- the number of filters in the last conv layer
+ n_layers (int) -- the number of conv layers in the discriminator
+ norm_layer -- normalization layer
+ """
+ super(NLayerDiscriminator, self).__init__()
+ if not use_actnorm:
+ norm_layer = nn.BatchNorm2d
+ else:
+ norm_layer = ActNorm
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
+ use_bias = norm_layer.func != nn.BatchNorm2d
+ else:
+ use_bias = norm_layer != nn.BatchNorm2d
+
+ kw = 4
+ padw = 1
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
+ nf_mult = 1
+ nf_mult_prev = 1
+ for n in range(1, n_layers): # gradually increase the number of filters
+ nf_mult_prev = nf_mult
+ nf_mult = min(2 ** n, 8)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ nf_mult_prev = nf_mult
+ nf_mult = min(2 ** n_layers, 8)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ sequence += [
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
+ self.main = nn.Sequential(*sequence)
+
+ def forward(self, input):
+ """Standard forward."""
+ return self.main(input)
diff --git a/StableSR/taming/modules/losses/__init__.py b/StableSR/taming/modules/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d09caf9eb805f849a517f1b23503e1a4d6ea1ec5
--- /dev/null
+++ b/StableSR/taming/modules/losses/__init__.py
@@ -0,0 +1,2 @@
+from taming.modules.losses.vqperceptual import DummyLoss
+
diff --git a/StableSR/taming/modules/losses/lpips.py b/StableSR/taming/modules/losses/lpips.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7280447694ffc302a7636e7e4d6183408e0aa95
--- /dev/null
+++ b/StableSR/taming/modules/losses/lpips.py
@@ -0,0 +1,123 @@
+"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
+
+import torch
+import torch.nn as nn
+from torchvision import models
+from collections import namedtuple
+
+from taming.util import get_ckpt_path
+
+
+class LPIPS(nn.Module):
+ # Learned perceptual metric
+ def __init__(self, use_dropout=True):
+ super().__init__()
+ self.scaling_layer = ScalingLayer()
+ self.chns = [64, 128, 256, 512, 512] # vgg16 features
+ self.net = vgg16(pretrained=True, requires_grad=False)
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
+ self.load_from_pretrained()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def load_from_pretrained(self, name="vgg_lpips"):
+ ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
+ self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
+ print("loaded pretrained LPIPS loss from {}".format(ckpt))
+
+ @classmethod
+ def from_pretrained(cls, name="vgg_lpips"):
+ if name != "vgg_lpips":
+ raise NotImplementedError
+ model = cls()
+ ckpt = get_ckpt_path(name)
+ model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
+ return model
+
+ def forward(self, input, target):
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
+ feats0, feats1, diffs = {}, {}, {}
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
+ for kk in range(len(self.chns)):
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
+
+ res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
+ val = res[0]
+ for l in range(1, len(self.chns)):
+ val += res[l]
+ return val
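+
+ # The returned LPIPS distance is the sum over the five VGG16 feature levels of the
+ # spatially averaged, learned 1x1-conv-weighted squared differences between the
+ # unit-normalized feature maps of the two inputs.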
+
+
+class ScalingLayer(nn.Module):
+ def __init__(self):
+ super(ScalingLayer, self).__init__()
+ self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
+ self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
+
+ def forward(self, inp):
+ return (inp - self.shift) / self.scale
+
+
+class NetLinLayer(nn.Module):
+ """ A single linear layer which does a 1x1 conv """
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
+ super(NetLinLayer, self).__init__()
+ layers = [nn.Dropout(), ] if (use_dropout) else []
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
+ self.model = nn.Sequential(*layers)
+
+
+class vgg16(torch.nn.Module):
+ def __init__(self, requires_grad=False, pretrained=True):
+ super(vgg16, self).__init__()
+ vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
+ self.slice1 = torch.nn.Sequential()
+ self.slice2 = torch.nn.Sequential()
+ self.slice3 = torch.nn.Sequential()
+ self.slice4 = torch.nn.Sequential()
+ self.slice5 = torch.nn.Sequential()
+ self.N_slices = 5
+ for x in range(4):
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(4, 9):
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(9, 16):
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(16, 23):
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(23, 30):
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
+ if not requires_grad:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, X):
+ h = self.slice1(X)
+ h_relu1_2 = h
+ h = self.slice2(h)
+ h_relu2_2 = h
+ h = self.slice3(h)
+ h_relu3_3 = h
+ h = self.slice4(h)
+ h_relu4_3 = h
+ h = self.slice5(h)
+ h_relu5_3 = h
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
+ return out
+
+
+def normalize_tensor(x,eps=1e-10):
+ norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
+ return x/(norm_factor+eps)
+
+
+def spatial_average(x, keepdim=True):
+ return x.mean([2,3],keepdim=keepdim)
+
diff --git a/StableSR/taming/modules/losses/segmentation.py b/StableSR/taming/modules/losses/segmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ba77deb5159a6307ed2acba9945e4764a4ff0a5
--- /dev/null
+++ b/StableSR/taming/modules/losses/segmentation.py
@@ -0,0 +1,22 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BCELoss(nn.Module):
+ def forward(self, prediction, target):
+ loss = F.binary_cross_entropy_with_logits(prediction,target)
+ return loss, {}
+
+
+class BCELossWithQuant(nn.Module):
+ def __init__(self, codebook_weight=1.):
+ super().__init__()
+ self.codebook_weight = codebook_weight
+
+ def forward(self, qloss, target, prediction, split):
+ bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
+ loss = bce_loss + self.codebook_weight*qloss
+ return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
+ "{}/bce_loss".format(split): bce_loss.detach().mean(),
+ "{}/quant_loss".format(split): qloss.detach().mean()
+ }
diff --git a/StableSR/taming/modules/losses/vqperceptual.py b/StableSR/taming/modules/losses/vqperceptual.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2febd445728479d4cd9aacdb2572cb1f1af04db
--- /dev/null
+++ b/StableSR/taming/modules/losses/vqperceptual.py
@@ -0,0 +1,136 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from taming.modules.losses.lpips import LPIPS
+from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
+
+
+class DummyLoss(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+
+def adopt_weight(weight, global_step, threshold=0, value=0.):
+ if global_step < threshold:
+ weight = value
+ return weight
+
+
+def hinge_d_loss(logits_real, logits_fake):
+ loss_real = torch.mean(F.relu(1. - logits_real))
+ loss_fake = torch.mean(F.relu(1. + logits_fake))
+ d_loss = 0.5 * (loss_real + loss_fake)
+ return d_loss
+
+
+def vanilla_d_loss(logits_real, logits_fake):
+ d_loss = 0.5 * (
+ torch.mean(torch.nn.functional.softplus(-logits_real)) +
+ torch.mean(torch.nn.functional.softplus(logits_fake)))
+ return d_loss
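+
+# hinge_d_loss is the hinge formulation of the discriminator objective, while
+# vanilla_d_loss is the standard non-saturating (softplus) form, equivalent to
+# binary cross-entropy with logits on the real/fake predictions.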
+
+
+class VQLPIPSWithDiscriminator(nn.Module):
+ def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+ disc_ndf=64, disc_loss="hinge"):
+ super().__init__()
+ assert disc_loss in ["hinge", "vanilla"]
+ self.codebook_weight = codebook_weight
+ self.pixel_weight = pixelloss_weight
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
+ n_layers=disc_num_layers,
+ use_actnorm=use_actnorm,
+ ndf=disc_ndf
+ ).apply(weights_init)
+ self.discriminator_iter_start = disc_start
+ if disc_loss == "hinge":
+ self.disc_loss = hinge_d_loss
+ elif disc_loss == "vanilla":
+ self.disc_loss = vanilla_d_loss
+ else:
+ raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
+ print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.disc_conditional = disc_conditional
+
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+ if last_layer is not None:
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+ else:
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
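+
+ # The adaptive weight rescales the generator (GAN) loss so that its gradient norm
+ # at the decoder's last layer matches that of the reconstruction loss, following
+ # the VQGAN formulation: d_weight = ||grad L_rec|| / (||grad L_GAN|| + 1e-4),
+ # clamped and scaled by discriminator_weight.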
+
+ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
+ global_step, last_layer=None, cond=None, split="train"):
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+ if self.perceptual_weight > 0:
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+ else:
+ p_loss = torch.tensor([0.0])
+
+ nll_loss = rec_loss
+ #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+ nll_loss = torch.mean(nll_loss)
+
+ # now the GAN part
+ if optimizer_idx == 0:
+ # generator update
+ if cond is None:
+ assert not self.disc_conditional
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ else:
+ assert self.disc_conditional
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
+ g_loss = -torch.mean(logits_fake)
+
+ try:
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
+ except RuntimeError:
+ assert not self.training
+ d_weight = torch.tensor(0.0)
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
+
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
+ "{}/quant_loss".format(split): codebook_loss.detach().mean(),
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
+ "{}/p_loss".format(split): p_loss.detach().mean(),
+ "{}/d_weight".format(split): d_weight.detach(),
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
+ "{}/g_loss".format(split): g_loss.detach().mean(),
+ }
+ return loss, log
+
+ if optimizer_idx == 1:
+ # second pass for discriminator update
+ if cond is None:
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+ else:
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+ log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+ "{}/logits_real".format(split): logits_real.detach().mean(),
+ "{}/logits_fake".format(split): logits_fake.detach().mean()
+ }
+ return d_loss, log
diff --git a/StableSR/taming/modules/misc/coord.py b/StableSR/taming/modules/misc/coord.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee69b0c897b6b382ae673622e420f55e494f5b09
--- /dev/null
+++ b/StableSR/taming/modules/misc/coord.py
@@ -0,0 +1,31 @@
+import torch
+
+class CoordStage(object):
+ def __init__(self, n_embed, down_factor):
+ self.n_embed = n_embed
+ self.down_factor = down_factor
+
+ def eval(self):
+ return self
+
+ def encode(self, c):
+ """fake vqmodel interface"""
+ assert 0.0 <= c.min() and c.max() <= 1.0
+ b,ch,h,w = c.shape
+ assert ch == 1
+
+ c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
+ mode="area")
+ c = c.clamp(0.0, 1.0)
+ c = self.n_embed*c
+ c_quant = c.round()
+ c_ind = c_quant.to(dtype=torch.long)
+
+ info = None, None, c_ind
+ return c_quant, None, info
+
+ def decode(self, c):
+ c = c/self.n_embed
+ c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
+ mode="nearest")
+ return c
diff --git a/StableSR/taming/modules/transformer/mingpt.py b/StableSR/taming/modules/transformer/mingpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..d14b7b68117f4b9f297b2929397cd4f55089334c
--- /dev/null
+++ b/StableSR/taming/modules/transformer/mingpt.py
@@ -0,0 +1,415 @@
+"""
+taken from: https://github.com/karpathy/minGPT/
+GPT model:
+- the initial stem consists of a combination of token encoding and a positional encoding
+- the meat of it is a uniform sequence of Transformer blocks
+ - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
+ - all blocks feed into a central residual pathway similar to resnets
+- the final decoder is a linear projection into a vanilla Softmax classifier
+"""
+
+import math
+import logging
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from transformers import top_k_top_p_filtering
+
+logger = logging.getLogger(__name__)
+
+
+class GPTConfig:
+ """ base GPT config, params common to all GPT versions """
+ embd_pdrop = 0.1
+ resid_pdrop = 0.1
+ attn_pdrop = 0.1
+
+ def __init__(self, vocab_size, block_size, **kwargs):
+ self.vocab_size = vocab_size
+ self.block_size = block_size
+ for k,v in kwargs.items():
+ setattr(self, k, v)
+
+
+class GPT1Config(GPTConfig):
+ """ GPT-1 like network roughly 125M params """
+ n_layer = 12
+ n_head = 12
+ n_embd = 768
+
+
+class CausalSelfAttention(nn.Module):
+ """
+ A vanilla multi-head masked self-attention layer with a projection at the end.
+ It is possible to use torch.nn.MultiheadAttention here, but an explicit
+ implementation is included to show that there is nothing too scary about it.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ assert config.n_embd % config.n_head == 0
+ # key, query, value projections for all heads
+ self.key = nn.Linear(config.n_embd, config.n_embd)
+ self.query = nn.Linear(config.n_embd, config.n_embd)
+ self.value = nn.Linear(config.n_embd, config.n_embd)
+ # regularization
+ self.attn_drop = nn.Dropout(config.attn_pdrop)
+ self.resid_drop = nn.Dropout(config.resid_pdrop)
+ # output projection
+ self.proj = nn.Linear(config.n_embd, config.n_embd)
+ # causal mask to ensure that attention is only applied to the left in the input sequence
+ mask = torch.tril(torch.ones(config.block_size,
+ config.block_size))
+ if hasattr(config, "n_unmasked"):
+ mask[:config.n_unmasked, :config.n_unmasked] = 1
+ self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
+ self.n_head = config.n_head
+
+ def forward(self, x, layer_past=None):
+ B, T, C = x.size()
+
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+ k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+ q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+ v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+
+ present = torch.stack((k, v))
+ if layer_past is not None:
+ past_key, past_value = layer_past
+ k = torch.cat((past_key, k), dim=-2)
+ v = torch.cat((past_value, v), dim=-2)
+
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+ if layer_past is None:
+ att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
+
+ att = F.softmax(att, dim=-1)
+ att = self.attn_drop(att)
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+
+ # output projection
+ y = self.resid_drop(self.proj(y))
+ return y, present # TODO: check that this does not break anything
+
+
+class Block(nn.Module):
+ """ an unassuming Transformer block """
+ def __init__(self, config):
+ super().__init__()
+ self.ln1 = nn.LayerNorm(config.n_embd)
+ self.ln2 = nn.LayerNorm(config.n_embd)
+ self.attn = CausalSelfAttention(config)
+ self.mlp = nn.Sequential(
+ nn.Linear(config.n_embd, 4 * config.n_embd),
+ nn.GELU(), # nice
+ nn.Linear(4 * config.n_embd, config.n_embd),
+ nn.Dropout(config.resid_pdrop),
+ )
+
+ def forward(self, x, layer_past=None, return_present=False):
+ # TODO: check that training still works
+ if return_present: assert not self.training
+ # layer past: tuple of length two with B, nh, T, hs
+ attn, present = self.attn(self.ln1(x), layer_past=layer_past)
+
+ x = x + attn
+ x = x + self.mlp(self.ln2(x))
+ if layer_past is not None or return_present:
+ return x, present
+ return x
+
+
+class GPT(nn.Module):
+ """ the full GPT language model, with a context size of block_size """
+ def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256,
+ embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
+ super().__init__()
+ config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
+ embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
+ n_layer=n_layer, n_head=n_head, n_embd=n_embd,
+ n_unmasked=n_unmasked)
+ # input embedding stem
+ self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
+ self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
+ self.drop = nn.Dropout(config.embd_pdrop)
+ # transformer
+ self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
+ # decoder head
+ self.ln_f = nn.LayerNorm(config.n_embd)
+ self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+ self.block_size = config.block_size
+ self.apply(self._init_weights)
+ self.config = config
+ logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
+
+ def get_block_size(self):
+ return self.block_size
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def forward(self, idx, embeddings=None, targets=None):
+ # forward the GPT model
+ token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
+
+ if embeddings is not None: # prepend explicit embeddings
+ token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
+
+ t = token_embeddings.shape[1]
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
+ position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
+ x = self.drop(token_embeddings + position_embeddings)
+ x = self.blocks(x)
+ x = self.ln_f(x)
+ logits = self.head(x)
+
+ # if we are given some desired targets also calculate the loss
+ loss = None
+ if targets is not None:
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+
+ return logits, loss
+
+ def forward_with_past(self, idx, embeddings=None, targets=None, past=None, past_length=None):
+ # inference only
+ assert not self.training
+ token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
+ if embeddings is not None: # prepend explicit embeddings
+ token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
+
+ if past is not None:
+ assert past_length is not None
+ past = torch.cat(past, dim=-2) # n_layer, 2, b, nh, len_past, dim_head
+ past_shape = list(past.shape)
+ expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head]
+ assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}"
+ position_embeddings = self.pos_emb[:, past_length, :] # each position maps to a (learnable) vector
+ else:
+ position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :]
+
+ x = self.drop(token_embeddings + position_embeddings)
+ presents = [] # accumulate over layers
+ for i, block in enumerate(self.blocks):
+ x, present = block(x, layer_past=past[i, ...] if past is not None else None, return_present=True)
+ presents.append(present)
+
+ x = self.ln_f(x)
+ logits = self.head(x)
+ # if we are given some desired targets also calculate the loss
+ loss = None
+ if targets is not None:
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+
+ return logits, loss, torch.stack(presents) # _, _, n_layer, 2, b, nh, 1, dim_head
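+
+ # forward_with_past implements incremental decoding: each block returns its stacked
+ # (k, v) as `present`, the caller accumulates these into `past`, and on subsequent
+ # steps only the newest token is fed while attention reuses the cached keys/values;
+ # pos_emb[:, past_length, :] selects the positional embedding of the new position.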
+
+
+class DummyGPT(nn.Module):
+ # for debugging
+ def __init__(self, add_value=1):
+ super().__init__()
+ self.add_value = add_value
+
+ def forward(self, idx):
+ return idx + self.add_value, None
+
+
+class CodeGPT(nn.Module):
+ """Takes in semi-embeddings"""
+ def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256,
+ embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
+ super().__init__()
+ config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
+ embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
+ n_layer=n_layer, n_head=n_head, n_embd=n_embd,
+ n_unmasked=n_unmasked)
+ # input embedding stem
+ self.tok_emb = nn.Linear(in_channels, config.n_embd)
+ self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
+ self.drop = nn.Dropout(config.embd_pdrop)
+ # transformer
+ self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
+ # decoder head
+ self.ln_f = nn.LayerNorm(config.n_embd)
+ self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+ self.block_size = config.block_size
+ self.apply(self._init_weights)
+ self.config = config
+ logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
+
+ def get_block_size(self):
+ return self.block_size
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def forward(self, idx, embeddings=None, targets=None):
+ # forward the GPT model
+ token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
+
+ if embeddings is not None: # prepend explicit embeddings
+ token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
+
+ t = token_embeddings.shape[1]
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
+ position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
+ x = self.drop(token_embeddings + position_embeddings)
+ x = self.blocks(x)
+ x = self.ln_f(x)
+ logits = self.head(x)
+
+ # if we are given some desired targets also calculate the loss
+ loss = None
+ if targets is not None:
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+
+ return logits, loss
+
+
+
+#### sampling utils
+
+def top_k_logits(logits, k):
+ v, ix = torch.topk(logits, k)
+ out = logits.clone()
+ out[out < v[:, [-1]]] = -float('Inf')
+ return out
+
+@torch.no_grad()
+def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
+ """
+ take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
+ the sequence, feeding the predictions back into the model each time. Clearly the sampling
+ has quadratic complexity unlike an RNN that is only linear, and has a finite context window
+ of block_size, unlike an RNN that has an infinite context window.
+ """
+ block_size = model.get_block_size()
+ model.eval()
+ for k in range(steps):
+ x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
+ logits, _ = model(x_cond)
+ # pluck the logits at the final step and scale by temperature
+ logits = logits[:, -1, :] / temperature
+ # optionally crop probabilities to only the top k options
+ if top_k is not None:
+ logits = top_k_logits(logits, top_k)
+ # apply softmax to convert to probabilities
+ probs = F.softmax(logits, dim=-1)
+ # sample from the distribution or take the most likely
+ if sample:
+ ix = torch.multinomial(probs, num_samples=1)
+ else:
+ _, ix = torch.topk(probs, k=1, dim=-1)
+ # append to the sequence and continue
+ x = torch.cat((x, ix), dim=1)
+
+ return x
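+
+# Illustrative usage (hypothetical names): given a trained GPT `model` and
+# conditioning indices `c` of shape (b, t), sample(model, c, steps=256,
+# temperature=1.0, sample=True, top_k=100) returns indices of shape (b, t + 256),
+# i.e. the conditioning sequence with 256 newly sampled tokens appended.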
+
+
+@torch.no_grad()
+def sample_with_past(x, model, steps, temperature=1., sample_logits=True,
+ top_k=None, top_p=None, callback=None):
+ # x is conditioning
+ sample = x
+ cond_len = x.shape[1]
+ past = None
+ for n in range(steps):
+ if callback is not None:
+ callback(n)
+ logits, _, present = model.forward_with_past(x, past=past, past_length=(n+cond_len-1))
+ if past is None:
+ past = [present]
+ else:
+ past.append(present)
+ logits = logits[:, -1, :] / temperature
+ if top_k is not None:
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
+
+ probs = F.softmax(logits, dim=-1)
+ if not sample_logits:
+ _, x = torch.topk(probs, k=1, dim=-1)
+ else:
+ x = torch.multinomial(probs, num_samples=1)
+ # append to the sequence and continue
+ sample = torch.cat((sample, x), dim=1)
+ del past
+ sample = sample[:, cond_len:] # cut conditioning off
+ return sample
+
+
+#### clustering utils
+
+class KMeans(nn.Module):
+ def __init__(self, ncluster=512, nc=3, niter=10):
+ super().__init__()
+ self.ncluster = ncluster
+ self.nc = nc
+ self.niter = niter
+ self.shape = (3,32,32)
+ self.register_buffer("C", torch.zeros(self.ncluster,nc))
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
+
+ def is_initialized(self):
+ return self.initialized.item() == 1
+
+ @torch.no_grad()
+ def initialize(self, x):
+ N, D = x.shape
+ assert D == self.nc, D
+ c = x[torch.randperm(N)[:self.ncluster]] # init clusters at random
+ for i in range(self.niter):
+ # assign all pixels to the closest codebook element
+ a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1)
+ # move each codebook element to be the mean of the pixels assigned to it
+ c = torch.stack([x[a==k].mean(0) for k in range(self.ncluster)])
+ # re-assign any poorly positioned codebook elements
+ nanix = torch.any(torch.isnan(c), dim=1)
+ ndead = nanix.sum().item()
+ print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead))
+ c[nanix] = x[torch.randperm(N)[:ndead]] # re-init dead clusters
+
+ self.C.copy_(c)
+ self.initialized.fill_(1)
+
+
+ def forward(self, x, reverse=False, shape=None):
+ if not reverse:
+ # flatten
+ bs,c,h,w = x.shape
+ assert c == self.nc
+ x = x.reshape(bs,c,h*w,1)
+ C = self.C.permute(1,0)
+ C = C.reshape(1,c,1,self.ncluster)
+ a = ((x-C)**2).sum(1).argmin(-1) # bs, h*w indices
+ return a
+ else:
+ # flatten
+ bs, HW = x.shape
+ """
+ c = self.C.reshape( 1, self.nc, 1, self.ncluster)
+ c = c[bs*[0],:,:,:]
+ c = c[:,:,HW*[0],:]
+ x = x.reshape(bs, 1, HW, 1)
+ x = x[:,3*[0],:,:]
+ x = torch.gather(c, dim=3, index=x)
+ """
+ x = self.C[x]
+ x = x.permute(0,2,1)
+ shape = shape if shape is not None else self.shape
+ x = x.reshape(bs, *shape)
+
+ return x
diff --git a/StableSR/taming/modules/transformer/permuter.py b/StableSR/taming/modules/transformer/permuter.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d43bb135adde38d94bf18a7e5edaa4523cd95cf
--- /dev/null
+++ b/StableSR/taming/modules/transformer/permuter.py
@@ -0,0 +1,248 @@
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+class AbstractPermuter(nn.Module):
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ def forward(self, x, reverse=False):
+ raise NotImplementedError
+
+
+class Identity(AbstractPermuter):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, reverse=False):
+ return x
+
+
+class Subsample(AbstractPermuter):
+ def __init__(self, H, W):
+ super().__init__()
+ C = 1
+ indices = np.arange(H*W).reshape(C,H,W)
+ while min(H, W) > 1:
+ indices = indices.reshape(C,H//2,2,W//2,2)
+ indices = indices.transpose(0,2,4,1,3)
+ indices = indices.reshape(C*4,H//2, W//2)
+ H = H//2
+ W = W//2
+ C = C*4
+ assert H == W == 1
+ idx = torch.tensor(indices.ravel())
+ self.register_buffer('forward_shuffle_idx',
+ nn.Parameter(idx, requires_grad=False))
+ self.register_buffer('backward_shuffle_idx',
+ nn.Parameter(torch.argsort(idx), requires_grad=False))
+
+ def forward(self, x, reverse=False):
+ if not reverse:
+ return x[:, self.forward_shuffle_idx]
+ else:
+ return x[:, self.backward_shuffle_idx]
+
+
+def mortonify(i, j):
+ """(i,j) index to linear morton code"""
+ i = np.uint64(i)
+ j = np.uint64(j)
+
+ z = np.uint(0)
+
+ for pos in range(32):
+ z = (z |
+ ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) |
+ ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1))
+ )
+ return z
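+
+# mortonify interleaves the bits of (i, j) into a Z-order (Morton) index: j fills the
+# even bit positions and i the odd ones, e.g. mortonify(0, 0) = 0, mortonify(0, 1) = 1,
+# mortonify(1, 0) = 2 and mortonify(1, 1) = 3.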
+
+
+class ZCurve(AbstractPermuter):
+ def __init__(self, H, W):
+ super().__init__()
+ reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)]
+ idx = np.argsort(reverseidx)
+ idx = torch.tensor(idx)
+ reverseidx = torch.tensor(reverseidx)
+ self.register_buffer('forward_shuffle_idx',
+ idx)
+ self.register_buffer('backward_shuffle_idx',
+ reverseidx)
+
+ def forward(self, x, reverse=False):
+ if not reverse:
+ return x[:, self.forward_shuffle_idx]
+ else:
+ return x[:, self.backward_shuffle_idx]
+
+
+class SpiralOut(AbstractPermuter):
+ def __init__(self, H, W):
+ super().__init__()
+ assert H == W
+ size = W
+ indices = np.arange(size*size).reshape(size,size)
+
+ i0 = size//2
+ j0 = size//2-1
+
+ i = i0
+ j = j0
+
+ idx = [indices[i0, j0]]
+ step_mult = 0
+ for c in range(1, size//2+1):
+ step_mult += 1
+ # steps left
+ for k in range(step_mult):
+ i = i - 1
+ j = j
+ idx.append(indices[i, j])
+
+ # step down
+ for k in range(step_mult):
+ i = i
+ j = j + 1
+ idx.append(indices[i, j])
+
+ step_mult += 1
+ if c < size//2:
+ # step right
+ for k in range(step_mult):
+ i = i + 1
+ j = j
+ idx.append(indices[i, j])
+
+ # step up
+ for k in range(step_mult):
+ i = i
+ j = j - 1
+ idx.append(indices[i, j])
+ else:
+ # end reached
+ for k in range(step_mult-1):
+ i = i + 1
+ idx.append(indices[i, j])
+
+ assert len(idx) == size*size
+ idx = torch.tensor(idx)
+ self.register_buffer('forward_shuffle_idx', idx)
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+ def forward(self, x, reverse=False):
+ if not reverse:
+ return x[:, self.forward_shuffle_idx]
+ else:
+ return x[:, self.backward_shuffle_idx]
+
+
+class SpiralIn(AbstractPermuter):
+ def __init__(self, H, W):
+ super().__init__()
+ assert H == W
+ size = W
+ indices = np.arange(size*size).reshape(size,size)
+
+ i0 = size//2
+ j0 = size//2-1
+
+ i = i0
+ j = j0
+
+ idx = [indices[i0, j0]]
+ step_mult = 0
+ for c in range(1, size//2+1):
+ step_mult += 1
+ # steps left
+ for k in range(step_mult):
+ i = i - 1
+ j = j
+ idx.append(indices[i, j])
+
+ # step down
+ for k in range(step_mult):
+ i = i
+ j = j + 1
+ idx.append(indices[i, j])
+
+ step_mult += 1
+ if c < size//2:
+ # step right
+ for k in range(step_mult):
+ i = i + 1
+ j = j
+ idx.append(indices[i, j])
+
+ # step up
+ for k in range(step_mult):
+ i = i
+ j = j - 1
+ idx.append(indices[i, j])
+ else:
+ # end reached
+ for k in range(step_mult-1):
+ i = i + 1
+ idx.append(indices[i, j])
+
+ assert len(idx) == size*size
+ idx = idx[::-1]
+ idx = torch.tensor(idx)
+ self.register_buffer('forward_shuffle_idx', idx)
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+ def forward(self, x, reverse=False):
+ if not reverse:
+ return x[:, self.forward_shuffle_idx]
+ else:
+ return x[:, self.backward_shuffle_idx]
+
+
+class Random(nn.Module):
+ def __init__(self, H, W):
+ super().__init__()
+ indices = np.random.RandomState(1).permutation(H*W)
+ idx = torch.tensor(indices.ravel())
+ self.register_buffer('forward_shuffle_idx', idx)
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+ def forward(self, x, reverse=False):
+ if not reverse:
+ return x[:, self.forward_shuffle_idx]
+ else:
+ return x[:, self.backward_shuffle_idx]
+
+
+class AlternateParsing(AbstractPermuter):
+ def __init__(self, H, W):
+ super().__init__()
+ indices = np.arange(W*H).reshape(H,W)
+ for i in range(1, H, 2):
+ indices[i, :] = indices[i, ::-1]
+ idx = indices.flatten()
+ assert len(idx) == H*W
+ idx = torch.tensor(idx)
+ self.register_buffer('forward_shuffle_idx', idx)
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+ def forward(self, x, reverse=False):
+ if not reverse:
+ return x[:, self.forward_shuffle_idx]
+ else:
+ return x[:, self.backward_shuffle_idx]
+
+
+if __name__ == "__main__":
+ p0 = AlternateParsing(16, 16)
+ print(p0.forward_shuffle_idx)
+ print(p0.backward_shuffle_idx)
+
+ x = torch.randint(0, 768, size=(11, 256))
+ y = p0(x)
+ xre = p0(y, reverse=True)
+ assert torch.equal(x, xre)
+
+ p1 = SpiralOut(2, 2)
+ print(p1.forward_shuffle_idx)
+ print(p1.backward_shuffle_idx)
diff --git a/StableSR/taming/modules/util.py b/StableSR/taming/modules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ee16385d8b1342a2d60a5f1aa5cadcfbe934bd8
--- /dev/null
+++ b/StableSR/taming/modules/util.py
@@ -0,0 +1,130 @@
+import torch
+import torch.nn as nn
+
+
+def count_params(model):
+ total_params = sum(p.numel() for p in model.parameters())
+ return total_params
+
+
+class ActNorm(nn.Module):
+ def __init__(self, num_features, logdet=False, affine=True,
+ allow_reverse_init=False):
+ assert affine
+ super().__init__()
+ self.logdet = logdet
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
+ self.allow_reverse_init = allow_reverse_init
+
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
+
+ def initialize(self, input):
+ with torch.no_grad():
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
+ mean = (
+ flatten.mean(1)
+ .unsqueeze(1)
+ .unsqueeze(2)
+ .unsqueeze(3)
+ .permute(1, 0, 2, 3)
+ )
+ std = (
+ flatten.std(1)
+ .unsqueeze(1)
+ .unsqueeze(2)
+ .unsqueeze(3)
+ .permute(1, 0, 2, 3)
+ )
+
+ self.loc.data.copy_(-mean)
+ self.scale.data.copy_(1 / (std + 1e-6))
+
+ def forward(self, input, reverse=False):
+ if reverse:
+ return self.reverse(input)
+ if len(input.shape) == 2:
+ input = input[:,:,None,None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ _, _, height, width = input.shape
+
+ if self.training and self.initialized.item() == 0:
+ self.initialize(input)
+ self.initialized.fill_(1)
+
+ h = self.scale * (input + self.loc)
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+
+ if self.logdet:
+ log_abs = torch.log(torch.abs(self.scale))
+ logdet = height*width*torch.sum(log_abs)
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
+ return h, logdet
+
+ return h
+
+ def reverse(self, output):
+ if self.training and self.initialized.item() == 0:
+ if not self.allow_reverse_init:
+ raise RuntimeError(
+ "Initializing ActNorm in reverse direction is "
+ "disabled by default. Use allow_reverse_init=True to enable."
+ )
+ else:
+ self.initialize(output)
+ self.initialized.fill_(1)
+
+ if len(output.shape) == 2:
+ output = output[:,:,None,None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ h = output / self.scale - self.loc
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+ return h
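+
+# Note: ActNorm applies a per-channel affine transform with data-dependent
+# initialization (as in Glow): on the first training batch, loc and scale are set
+# so that each channel of the output has zero mean and unit variance.
+# Minimal usage sketch (illustrative values):
+#   norm = ActNorm(num_features=64)
+#   y = norm(torch.randn(8, 64, 32, 32))  # first call in train mode triggers initialization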
+
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+class Labelator(AbstractEncoder):
+ """Net2Net Interface for Class-Conditional Model"""
+ def __init__(self, n_classes, quantize_interface=True):
+ super().__init__()
+ self.n_classes = n_classes
+ self.quantize_interface = quantize_interface
+
+ def encode(self, c):
+ c = c[:,None]
+ if self.quantize_interface:
+ return c, None, [None, None, c.long()]
+ return c
+
+
+class SOSProvider(AbstractEncoder):
+ # for unconditional training
+ def __init__(self, sos_token, quantize_interface=True):
+ super().__init__()
+ self.sos_token = sos_token
+ self.quantize_interface = quantize_interface
+
+ def encode(self, x):
+ # get batch size from data and replicate sos_token
+ c = torch.ones(x.shape[0], 1)*self.sos_token
+ c = c.long().to(x.device)
+ if self.quantize_interface:
+ return c, None, [None, None, c]
+ return c
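+
+# Note: with quantize_interface=True, Labelator.encode and SOSProvider.encode mimic the
+# (quant, loss, info) return signature of a quantized first stage, so these conditioning
+# providers can be used interchangeably with VQ encoders in the Net2Net interface.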
diff --git a/StableSR/taming/modules/vqvae/quantize.py b/StableSR/taming/modules/vqvae/quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..d75544e41fa01bce49dd822b1037963d62f79b51
--- /dev/null
+++ b/StableSR/taming/modules/vqvae/quantize.py
@@ -0,0 +1,445 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from torch import einsum
+from einops import rearrange
+
+
+class VectorQuantizer(nn.Module):
+ """
+ see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
+ ____________________________________________
+ Discretization bottleneck part of the VQ-VAE.
+ Inputs:
+ - n_e : number of embeddings
+ - e_dim : dimension of embedding
+ - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
+ _____________________________________________
+ """
+
+ # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
+ # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
+ # used wherever VectorQuantizer has been used before and is additionally
+ # more efficient.
+ def __init__(self, n_e, e_dim, beta):
+ super(VectorQuantizer, self).__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ def forward(self, z):
+ """
+ Inputs the output of the encoder network z and maps it to a discrete
+ one-hot vector that is the index of the closest embedding vector e_j
+ z (continuous) -> z_q (discrete)
+ z.shape = (batch, channel, height, width)
+ quantization pipeline:
+ 1. get encoder input (B,C,H,W)
+ 2. flatten input to (B*H*W,C)
+ """
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = z.permute(0, 2, 3, 1).contiguous()
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
+ torch.matmul(z_flattened, self.embedding.weight.t())
+
+        ## could possibly replace this here
+ # #\start...
+ # find closest encodings
+ min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
+
+ min_encodings = torch.zeros(
+ min_encoding_indices.shape[0], self.n_e).to(z)
+ min_encodings.scatter_(1, min_encoding_indices, 1)
+
+ # dtype min encodings: torch.float32
+ # min_encodings shape: torch.Size([2048, 512])
+ # min_encoding_indices.shape: torch.Size([2048, 1])
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
+ #.........\end
+
+ # with:
+ # .........\start
+ #min_encoding_indices = torch.argmin(d, dim=1)
+ #z_q = self.embedding(min_encoding_indices)
+ # ......\end......... (TODO)
+
+ # compute loss for embedding
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
+ torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # perplexity
+ e_mean = torch.mean(min_encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
+
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape):
+ # shape specifying (batch, height, width, channel)
+ # TODO: check for more easy handling with nn.Embedding
+ min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
+ min_encodings.scatter_(1, indices[:,None], 1)
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
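+
+# Minimal usage sketch (illustrative values only):
+#   quantizer = VectorQuantizer(n_e=512, e_dim=64, beta=0.25)
+#   z = torch.randn(2, 64, 16, 16)                       # encoder output (B, C, H, W)
+#   z_q, loss, (perplexity, _, indices) = quantizer(z)   # z_q: (2, 64, 16, 16), indices: (512, 1)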
+
+
+class GumbelQuantize(nn.Module):
+ """
+ credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
+ Gumbel Softmax trick quantizer
+ Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
+ https://arxiv.org/abs/1611.01144
+ """
+ def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
+ kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
+ remap=None, unknown_index="random"):
+ super().__init__()
+
+ self.embedding_dim = embedding_dim
+ self.n_embed = n_embed
+
+ self.straight_through = straight_through
+ self.temperature = temp_init
+ self.kl_weight = kl_weight
+
+ self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
+ self.embed = nn.Embedding(n_embed, embedding_dim)
+
+ self.use_vqinterface = use_vqinterface
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed+1
+ print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_embed
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ match = (inds[:,:,None]==used[None,None,...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2)<1
+ if self.unknown_index == "random":
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z, temp=None, return_logits=False):
+ # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
+ hard = self.straight_through if self.training else True
+ temp = self.temperature if temp is None else temp
+
+ logits = self.proj(z)
+ if self.remap is not None:
+ # continue only with used logits
+ full_zeros = torch.zeros_like(logits)
+ logits = logits[:,self.used,...]
+
+ soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
+ if self.remap is not None:
+ # go back to all entries but unused set to zero
+ full_zeros[:,self.used,...] = soft_one_hot
+ soft_one_hot = full_zeros
+ z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
+
+ # + kl divergence to the prior loss
+ qy = F.softmax(logits, dim=1)
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
+
+ ind = soft_one_hot.argmax(dim=1)
+ if self.remap is not None:
+ ind = self.remap_to_used(ind)
+ if self.use_vqinterface:
+ if return_logits:
+ return z_q, diff, (None, None, ind), logits
+ return z_q, diff, (None, None, ind)
+ return z_q, diff, ind
+
+ def get_codebook_entry(self, indices, shape):
+ b, h, w, c = shape
+ assert b*h*w == indices.shape[0]
+ indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
+ if self.remap is not None:
+ indices = self.unmap_to_all(indices)
+ one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
+ z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
+ return z_q
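+
+# Minimal usage sketch (illustrative values only): relaxed (soft) codes are used during
+# training and hard argmax codes at eval time via the straight-through estimator.
+#   gq = GumbelQuantize(num_hiddens=128, embedding_dim=64, n_embed=512)
+#   z_q, kl_loss, (_, _, ind) = gq(torch.randn(2, 128, 16, 16))
+#   # z_q: (2, 64, 16, 16), ind: (2, 16, 16)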
+
+
+class VectorQuantizer2(nn.Module):
+ """
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
+ """
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
+ # backwards compatibility we use the buggy version by default, but you can
+ # specify legacy=False to fix it.
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
+ sane_index_shape=False, legacy=True):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+ self.legacy = legacy
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed+1
+ print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_e
+
+ self.sane_index_shape = sane_index_shape
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ match = (inds[:,:,None]==used[None,None,...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2)<1
+ if self.unknown_index == "random":
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
+ assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
+ assert rescale_logits==False, "Only for interface compatible with Gumbel"
+ assert return_logits==False, "Only for interface compatible with Gumbel"
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = rearrange(z, 'b c h w -> b h w c').contiguous()
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
+ perplexity = None
+ min_encodings = None
+
+ # compute loss for embedding
+ if not self.legacy:
+ loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
+ torch.mean((z_q - z.detach()) ** 2)
+ else:
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
+ torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
+
+ if self.remap is not None:
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
+ min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten
+
+ if self.sane_index_shape:
+ min_encoding_indices = min_encoding_indices.reshape(
+ z_q.shape[0], z_q.shape[2], z_q.shape[3])
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape):
+ # shape specifying (batch, height, width, channel)
+ if self.remap is not None:
+ indices = indices.reshape(shape[0],-1) # add batch axis
+ indices = self.unmap_to_all(indices)
+ indices = indices.reshape(-1) # flatten again
+
+ # get quantized latent vectors
+ z_q = self.embedding(indices)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
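+
+# Note: with legacy=True (default) beta weights the codebook term, reproducing the original
+# behaviour; legacy=False weights the commitment term ||z - sg[z_q]||^2 instead, matching
+# the VQ-VAE formulation, e.g. (illustrative values):
+#   quantizer = VectorQuantizer2(n_e=512, e_dim=64, beta=0.25, legacy=False)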
+
+class EmbeddingEMA(nn.Module):
+ def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
+ super().__init__()
+ self.decay = decay
+ self.eps = eps
+ weight = torch.randn(num_tokens, codebook_dim)
+ self.weight = nn.Parameter(weight, requires_grad = False)
+ self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False)
+ self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False)
+ self.update = True
+
+ def forward(self, embed_id):
+ return F.embedding(embed_id, self.weight)
+
+ def cluster_size_ema_update(self, new_cluster_size):
+ self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
+
+ def embed_avg_ema_update(self, new_embed_avg):
+ self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
+
+ def weight_update(self, num_tokens):
+ n = self.cluster_size.sum()
+ smoothed_cluster_size = (
+ (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
+ )
+ #normalize embedding average with smoothed cluster size
+ embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
+ self.weight.data.copy_(embed_normalized)
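+
+# Note: the EMA codebook is excluded from the optimizer (requires_grad=False). Cluster
+# counts and per-code feature sums are tracked as exponential moving averages,
+#   N_i <- decay * N_i + (1 - decay) * n_i,   m_i <- decay * m_i + (1 - decay) * sum_of_z_in_cluster_i,
+# and weight_update sets each code to e_i = m_i / N_i after Laplace smoothing of N_i.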
+
+
+class EMAVectorQuantizer(nn.Module):
+ def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
+ remap=None, unknown_index="random"):
+ super().__init__()
+        self.codebook_dim = embedding_dim
+        self.num_tokens = n_embed
+ self.beta = beta
+ self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed+1
+ print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_embed
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ match = (inds[:,:,None]==used[None,None,...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2)<1
+ if self.unknown_index == "random":
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z):
+ # reshape z -> (batch, height, width, channel) and flatten
+ #z, 'b c h w -> b h w c'
+ z = rearrange(z, 'b c h w -> b h w c')
+ z_flattened = z.reshape(-1, self.codebook_dim)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
+ self.embedding.weight.pow(2).sum(dim=1) - 2 * \
+ torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
+
+
+ encoding_indices = torch.argmin(d, dim=1)
+
+ z_q = self.embedding(encoding_indices).view(z.shape)
+ encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
+ avg_probs = torch.mean(encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+
+ if self.training and self.embedding.update:
+ #EMA cluster size
+ encodings_sum = encodings.sum(0)
+ self.embedding.cluster_size_ema_update(encodings_sum)
+ #EMA embedding average
+ embed_sum = encodings.transpose(0,1) @ z_flattened
+ self.embedding.embed_avg_ema_update(embed_sum)
+ #normalize embed_avg and update weight
+ self.embedding.weight_update(self.num_tokens)
+
+ # compute loss for embedding
+ loss = self.beta * F.mse_loss(z_q.detach(), z)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ #z_q, 'b h w c -> b c h w'
+ z_q = rearrange(z_q, 'b h w c -> b c h w')
+ return z_q, loss, (perplexity, encodings, encoding_indices)
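+
+# Minimal usage sketch (illustrative values only): with EMA updates only the commitment loss
+# is backpropagated; the codebook itself is refreshed from the EMA statistics above.
+#   quantizer = EMAVectorQuantizer(n_embed=512, embedding_dim=64, beta=0.25)
+#   z_q, loss, _ = quantizer(torch.randn(2, 64, 16, 16))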
diff --git a/StableSR/taming/util.py b/StableSR/taming/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..06053e5defb87977f9ab07e69bf4da12201de9b7
--- /dev/null
+++ b/StableSR/taming/util.py
@@ -0,0 +1,157 @@
+import os, hashlib
+import requests
+from tqdm import tqdm
+
+URL_MAP = {
+ "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
+}
+
+CKPT_MAP = {
+ "vgg_lpips": "vgg.pth"
+}
+
+MD5_MAP = {
+ "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
+}
+
+
+def download(url, local_path, chunk_size=1024):
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
+ with requests.get(url, stream=True) as r:
+ total_size = int(r.headers.get("content-length", 0))
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
+ with open(local_path, "wb") as f:
+ for data in r.iter_content(chunk_size=chunk_size):
+ if data:
+ f.write(data)
+ pbar.update(chunk_size)
+
+
+def md5_hash(path):
+ with open(path, "rb") as f:
+ content = f.read()
+ return hashlib.md5(content).hexdigest()
+
+
+def get_ckpt_path(name, root, check=False):
+ assert name in URL_MAP
+ path = os.path.join(root, CKPT_MAP[name])
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
+ download(URL_MAP[name], path)
+ md5 = md5_hash(path)
+ assert md5 == MD5_MAP[name], md5
+ return path
+
+
+class KeyNotFoundError(Exception):
+ def __init__(self, cause, keys=None, visited=None):
+ self.cause = cause
+ self.keys = keys
+ self.visited = visited
+ messages = list()
+ if keys is not None:
+ messages.append("Key not found: {}".format(keys))
+ if visited is not None:
+ messages.append("Visited: {}".format(visited))
+ messages.append("Cause:\n{}".format(cause))
+ message = "\n".join(messages)
+ super().__init__(message)
+
+
+def retrieve(
+ list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
+):
+ """Given a nested list or dict return the desired value at key expanding
+ callable nodes if necessary and :attr:`expand` is ``True``. The expansion
+ is done in-place.
+
+ Parameters
+ ----------
+ list_or_dict : list or dict
+ Possibly nested list or dictionary.
+ key : str
+ key/to/value, path like string describing all keys necessary to
+ consider to get to the desired value. List indices can also be
+ passed here.
+ splitval : str
+ String that defines the delimiter between keys of the
+ different depth levels in `key`.
+ default : obj
+ Value returned if :attr:`key` is not found.
+ expand : bool
+ Whether to expand callable nodes on the path or not.
+
+ Returns
+ -------
+ The desired value or if :attr:`default` is not ``None`` and the
+ :attr:`key` is not found returns ``default``.
+
+ Raises
+ ------
+ Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
+ ``None``.
+ """
+
+ keys = key.split(splitval)
+
+ success = True
+ try:
+ visited = []
+ parent = None
+ last_key = None
+ for key in keys:
+ if callable(list_or_dict):
+ if not expand:
+ raise KeyNotFoundError(
+ ValueError(
+ "Trying to get past callable node with expand=False."
+ ),
+ keys=keys,
+ visited=visited,
+ )
+ list_or_dict = list_or_dict()
+ parent[last_key] = list_or_dict
+
+ last_key = key
+ parent = list_or_dict
+
+ try:
+ if isinstance(list_or_dict, dict):
+ list_or_dict = list_or_dict[key]
+ else:
+ list_or_dict = list_or_dict[int(key)]
+ except (KeyError, IndexError, ValueError) as e:
+ raise KeyNotFoundError(e, keys=keys, visited=visited)
+
+ visited += [key]
+ # final expansion of retrieved value
+ if expand and callable(list_or_dict):
+ list_or_dict = list_or_dict()
+ parent[last_key] = list_or_dict
+ except KeyNotFoundError as e:
+ if default is None:
+ raise e
+ else:
+ list_or_dict = default
+ success = False
+
+ if not pass_success:
+ return list_or_dict
+ else:
+ return list_or_dict, success
+
+
+if __name__ == "__main__":
+ config = {"keya": "a",
+ "keyb": "b",
+ "keyc":
+ {"cc1": 1,
+ "cc2": 2,
+ }
+ }
+ from omegaconf import OmegaConf
+ config = OmegaConf.create(config)
+ print(config)
+ retrieve(config, "keya")
+
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..188a716235f705cfc2e2ffa7bfc3d4f2175608db
--- /dev/null
+++ b/app.py
@@ -0,0 +1,363 @@
+"""
+This file is used for deploying hugging face demo:
+https://huggingface.co/spaces/
+"""
+
+import sys
+sys.path.append('StableSR')
+import os
+import cv2
+import torch
+import torch.nn.functional as F
+import gradio as gr
+import torchvision
+from torchvision.transforms.functional import normalize
+from ldm.util import instantiate_from_config
+from torch import autocast
+import PIL
+import numpy as np
+from pytorch_lightning import seed_everything
+from contextlib import nullcontext
+from omegaconf import OmegaConf
+from PIL import Image
+import copy
+from scripts.wavelet_color_fix import wavelet_reconstruction, adaptive_instance_normalization
+from scripts.util_image import ImageSpliterTh
+from basicsr.utils.download_util import load_file_from_url
+from einops import rearrange, repeat
+from itertools import islice
+
+# os.system("pip freeze")
+
+pretrain_model_url = {
+ 'stablesr_512': 'https://huggingface.co/Iceclear/StableSR/resolve/main/stablesr_000117.ckpt',
+ 'stablesr_768': 'https://huggingface.co/Iceclear/StableSR/resolve/main/stablesr_768v_000139.ckpt',
+ 'CFW': 'https://huggingface.co/Iceclear/StableSR/resolve/main/vqgan_cfw_00011.ckpt',
+}
+# download weights
+if not os.path.exists('./stablesr_000117.ckpt'):
+ load_file_from_url(url=pretrain_model_url['stablesr_512'], model_dir='./', progress=True, file_name=None)
+if not os.path.exists('./stablesr_768v_000139.ckpt'):
+ load_file_from_url(url=pretrain_model_url['stablesr_768'], model_dir='./', progress=True, file_name=None)
+if not os.path.exists('./vqgan_cfw_00011.ckpt'):
+ load_file_from_url(url=pretrain_model_url['CFW'], model_dir='./', progress=True, file_name=None)
+
+# download images
+torch.hub.download_url_to_file(
+ 'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/Lincoln.png',
+ '01.png')
+torch.hub.download_url_to_file(
+ 'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/oldphoto6.png',
+ '02.png')
+torch.hub.download_url_to_file(
+ 'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/comic2.png',
+ '03.png')
+torch.hub.download_url_to_file(
+ 'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet128/OST_120.png',
+ '04.png')
+torch.hub.download_url_to_file(
+ 'https://raw.githubusercontent.com/zsyOAOA/ResShift/master/testdata/RealSet65/comic3.png',
+ '05.png')
+
+def load_img(path):
+ image = Image.open(path).convert("RGB")
+ w, h = image.size
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.*image - 1.
+
+def space_timesteps(num_timesteps, section_counts):
+ """
+ Create a list of timesteps to use from an original diffusion process,
+ given the number of timesteps we want to take from equally-sized portions
+ of the original process.
+    For example, if there are 300 timesteps and the section counts are [10, 15, 20],
+    then the first 100 timesteps are strided down to 10 timesteps, the second 100
+    down to 15 timesteps, and the final 100 down to 20.
+ If the stride is a string starting with "ddim", then the fixed striding
+ from the DDIM paper is used, and only one section is allowed.
+ :param num_timesteps: the number of diffusion steps in the original
+ process to divide up.
+ :param section_counts: either a list of numbers, or a string containing
+ comma-separated numbers, indicating the step count
+ per section. As a special case, use "ddimN" where N
+ is a number of steps to use the striding from the
+ DDIM paper.
+ :return: a set of diffusion steps from the original process to use.
+ """
+ if isinstance(section_counts, str):
+ if section_counts.startswith("ddim"):
+ desired_count = int(section_counts[len("ddim"):])
+ for i in range(1, num_timesteps):
+ if len(range(0, num_timesteps, i)) == desired_count:
+ return set(range(0, num_timesteps, i))
+ raise ValueError(
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
+ )
+ section_counts = [int(x) for x in section_counts.split(",")] #[250,]
+ size_per = num_timesteps // len(section_counts)
+ extra = num_timesteps % len(section_counts)
+ start_idx = 0
+ all_steps = []
+ for i, section_count in enumerate(section_counts):
+ size = size_per + (1 if i < extra else 0)
+ if size < section_count:
+ raise ValueError(
+ f"cannot divide section of {size} steps into {section_count}"
+ )
+ if section_count <= 1:
+ frac_stride = 1
+ else:
+ frac_stride = (size - 1) / (section_count - 1)
+ cur_idx = 0.0
+ taken_steps = []
+ for _ in range(section_count):
+ taken_steps.append(start_idx + round(cur_idx))
+ cur_idx += frac_stride
+ all_steps += taken_steps
+ start_idx += size
+ return set(all_steps)
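+
+# Illustrative example: space_timesteps(1000, [10]) keeps 10 evenly spaced steps,
+# {0, 111, 222, 333, 444, 555, 666, 777, 888, 999}.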
+
+def chunk(it, size):
+ it = iter(it)
+ return iter(lambda: tuple(islice(it, size)), ())
+
+def load_model_from_config(config, ckpt, verbose=False):
+ print(f"Loading model from {ckpt}")
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ model = instantiate_from_config(config.model)
+ m, u = model.load_state_dict(sd, strict=False)
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ print(u)
+
+ model.cuda()
+ model.eval()
+ return model
+
+# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+device = torch.device("cuda")
+vqgan_config = OmegaConf.load("./configs/autoencoder/autoencoder_kl_64x64x4_resi.yaml")
+vq_model = load_model_from_config(vqgan_config, './vqgan_cfw_00011.ckpt')
+vq_model = vq_model.to(device)
+
+os.makedirs('output', exist_ok=True)
+
+def inference(image, upscale, dec_w, seed, model_type, ddpm_steps, colorfix_type):
+ """Run a single prediction on the model"""
+ precision_scope = autocast
+ vq_model.decoder.fusion_w = dec_w
+ seed_everything(seed)
+
+ if model_type == '512':
+ config = OmegaConf.load("./configs/stableSRNew/v2-finetune_text_T_512.yaml")
+ model = load_model_from_config(config, "./stablesr_000117.ckpt")
+ min_size = 512
+ else:
+ config = OmegaConf.load("./configs/stableSRNew/v2-finetune_text_T_768v.yaml")
+ model = load_model_from_config(config, "./stablesr_768v_000139.ckpt")
+ min_size = 768
+
+ model = model.to(device)
+ model.configs = config
+ model.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=0.00085, linear_end=0.0120, cosine_s=8e-3)
+ model.num_timesteps = 1000
+
+ sqrt_alphas_cumprod = copy.deepcopy(model.sqrt_alphas_cumprod)
+ sqrt_one_minus_alphas_cumprod = copy.deepcopy(model.sqrt_one_minus_alphas_cumprod)
+
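+    # Respace the diffusion schedule: keep only the selected timesteps and recompute the
+    # betas so that the cumulative alphas at the kept steps are preserved.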
+ use_timesteps = set(space_timesteps(1000, [ddpm_steps]))
+ last_alpha_cumprod = 1.0
+ new_betas = []
+ timestep_map = []
+ for i, alpha_cumprod in enumerate(model.alphas_cumprod):
+ if i in use_timesteps:
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+ last_alpha_cumprod = alpha_cumprod
+ timestep_map.append(i)
+ new_betas = [beta.data.cpu().numpy() for beta in new_betas]
+ model.register_schedule(given_betas=np.array(new_betas), timesteps=len(new_betas))
+ model.num_timesteps = 1000
+ model.ori_timesteps = list(use_timesteps)
+ model.ori_timesteps.sort()
+ model = model.to(device)
+
+ try: # global try
+ with torch.no_grad():
+ with precision_scope("cuda"):
+ with model.ema_scope():
+ init_image = load_img(image)
+ init_image = F.interpolate(
+ init_image,
+ size=(int(init_image.size(-2)*upscale),
+ int(init_image.size(-1)*upscale)),
+ mode='bicubic',
+ )
+
+ if init_image.size(-1) < min_size or init_image.size(-2) < min_size:
+ ori_size = init_image.size()
+ rescale = min_size * 1.0 / min(init_image.size(-2), init_image.size(-1))
+ new_h = max(int(ori_size[-2]*rescale), min_size)
+ new_w = max(int(ori_size[-1]*rescale), min_size)
+ init_template = F.interpolate(
+ init_image,
+ size=(new_h, new_w),
+ mode='bicubic',
+ )
+ else:
+ init_template = init_image
+ rescale = 1
+ init_template = init_template.clamp(-1, 1)
+ assert init_template.size(-1) >= min_size
+ assert init_template.size(-2) >= min_size
+
+ init_template = init_template.type(torch.float16).to(device)
+
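+                    # If the shorter side is at most 1280 px, the image is encoded and sampled as a
+                    # whole; otherwise it is split into overlapping patches (ImageSpliterTh), each
+                    # patch is restored via sample_canvas, and the results are stitched back together.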
+ if init_template.size(-1) <= 1280 or init_template.size(-2) <= 1280:
+ init_latent_generator, enc_fea_lq = vq_model.encode(init_template)
+ init_latent = model.get_first_stage_encoding(init_latent_generator)
+ text_init = ['']*init_template.size(0)
+ semantic_c = model.cond_stage_model(text_init)
+
+ noise = torch.randn_like(init_latent)
+
+ t = repeat(torch.tensor([999]), '1 -> b', b=init_image.size(0))
+ t = t.to(device).long()
+ x_T = model.q_sample_respace(x_start=init_latent, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise)
+
+ if init_template.size(-1)<= min_size and init_template.size(-2) <= min_size:
+ samples, _ = model.sample(cond=semantic_c, struct_cond=init_latent, batch_size=init_template.size(0), timesteps=ddpm_steps, time_replace=ddpm_steps, x_T=x_T, return_intermediates=True)
+ else:
+ samples, _ = model.sample_canvas(cond=semantic_c, struct_cond=init_latent, batch_size=init_template.size(0), timesteps=ddpm_steps, time_replace=ddpm_steps, x_T=x_T, return_intermediates=True, tile_size=int(min_size/8), tile_overlap=min_size//16, batch_size_sample=init_template.size(0))
+ x_samples = vq_model.decode(samples * 1. / model.scale_factor, enc_fea_lq)
+ if colorfix_type == 'adain':
+ x_samples = adaptive_instance_normalization(x_samples, init_template)
+ elif colorfix_type == 'wavelet':
+ x_samples = wavelet_reconstruction(x_samples, init_template)
+ x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
+ else:
+ im_spliter = ImageSpliterTh(init_template, 1280, 1000, sf=1)
+ for im_lq_pch, index_infos in im_spliter:
+ init_latent = model.get_first_stage_encoding(model.encode_first_stage(im_lq_pch)) # move to latent space
+ text_init = ['']*init_latent.size(0)
+ semantic_c = model.cond_stage_model(text_init)
+ noise = torch.randn_like(init_latent)
+ # If you would like to start from the intermediate steps, you can add noise to LR to the specific steps.
+ t = repeat(torch.tensor([999]), '1 -> b', b=init_template.size(0))
+ t = t.to(device).long()
+ x_T = model.q_sample_respace(x_start=init_latent, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise)
+ # x_T = noise
+ samples, _ = model.sample_canvas(cond=semantic_c, struct_cond=init_latent, batch_size=im_lq_pch.size(0), timesteps=ddpm_steps, time_replace=ddpm_steps, x_T=x_T, return_intermediates=True, tile_size=int(min_size/8), tile_overlap=min_size//16, batch_size_sample=im_lq_pch.size(0))
+ _, enc_fea_lq = vq_model.encode(im_lq_pch)
+ x_samples = vq_model.decode(samples * 1. / model.scale_factor, enc_fea_lq)
+ if colorfix_type == 'adain':
+ x_samples = adaptive_instance_normalization(x_samples, im_lq_pch)
+ elif colorfix_type == 'wavelet':
+ x_samples = wavelet_reconstruction(x_samples, im_lq_pch)
+ im_spliter.update(x_samples, index_infos)
+ x_samples = im_spliter.gather()
+ x_samples = torch.clamp((x_samples+1.0)/2.0, min=0.0, max=1.0)
+
+ if rescale > 1:
+ x_samples = F.interpolate(
+ x_samples,
+ size=(int(init_image.size(-2)),
+ int(init_image.size(-1))),
+ mode='bicubic',
+ )
+ x_samples = x_samples.clamp(0, 1)
+ x_sample = 255. * rearrange(x_samples[0].cpu().numpy(), 'c h w -> h w c')
+ restored_img = x_sample.astype(np.uint8)
+ Image.fromarray(x_sample.astype(np.uint8)).save(f'output/out.png')
+
+ return restored_img, f'output/out.png'
+ except Exception as error:
+ print('Global exception', error)
+ return None, None
+
+
+title = "Exploiting Diffusion Prior for Real-World Image Super-Resolution"
+description = r"""
+Official Gradio demo for Exploiting Diffusion Prior for Real-World Image Super-Resolution.
+🔥 StableSR is a general image super-resolution algorithm for real-world and AIGC images.
+"""
+article = r"""
+If StableSR is helpful, please help to ⭐ the Github Repo. Thanks!
+[![GitHub Stars](https://img.shields.io/github/stars/IceClear/StableSR?style=social)](https://github.com/IceClear/StableSR)
+
+---
+
+📝 **Citation**
+
+If our work is useful for your research, please consider citing:
+```bibtex
+@inproceedings{wang2023exploiting,
+ author = {Wang, Jianyi and Yue, Zongsheng and Zhou, Shangchen and Chan, Kelvin CK and Loy, Chen Change},
+ title = {Exploiting Diffusion Prior for Real-World Image Super-Resolution},
+ booktitle = {arXiv preprint arXiv:2305.07015},
+ year = {2023}
+}
+```
+
+📋 **License**
+
+This project is licensed under S-Lab License 1.0.
+Redistribution and use for non-commercial purposes should follow this license.
+
+📧 **Contact**
+
+If you have any questions, please feel free to reach out to me at iceclearwjy@gmail.com.
+
+
+ 🤗 Find Me:
+
+![Twitter Follow](https://img.shields.io/twitter/follow/Iceclearwjy?label=%40Iceclearwjy&style=social)
+
+![Github Follow](https://img.shields.io/github/followers/IceClear?style=social)
+
+
+
+"""
+
+demo = gr.Interface(
+ inference, [
+ gr.inputs.Image(type="filepath", label="Input"),
+ gr.inputs.Number(default=1, label="Rescaling_Factor (Large images require huge time)"),
+ gr.Slider(0, 1, value=0.5, step=0.01, label='CFW_Fidelity (0 for better quality, 1 for better identity)'),
+ gr.inputs.Number(default=42, label="Seeds"),
+ gr.Dropdown(
+ choices=["512", "768v"],
+ value="512",
+ label="Model",
+ ),
+ gr.Slider(10, 1000, value=200, step=1, label='Sampling timesteps for DDPM'),
+ gr.Dropdown(
+ choices=["none", "adain", "wavelet"],
+ value="adain",
+ label="Color_Correction",
+ ),
+ ], [
+ gr.outputs.Image(type="numpy", label="Output"),
+ gr.outputs.File(label="Download the output")
+ ],
+ title=title,
+ description=description,
+ article=article,
+ examples=[
+ ['./01.png', 4, 0.5, 42, "512", 200, "adain"],
+ ['./02.png', 4, 0.5, 42, "512", 200, "adain"],
+ ['./03.png', 4, 0.5, 42, "512", 200, "adain"],
+ ['./04.png', 4, 0.5, 42, "512", 200, "adain"],
+ ['./05.png', 4, 0.5, 42, "512", 200, "adain"]
+ ]
+ )
+
+demo.queue(concurrency_count=1)
+demo.launch(share=True)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f0c1810ed337dd86ffbf28b152325d6bd4e7f5db
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,24 @@
+torch==1.13.1
+torchvision==0.14.1
+albumentations==1.3.0
+opencv-python==4.6.0.66
+imageio==2.9.0
+numpy==1.23.1
+imageio-ffmpeg==0.4.2
+pytorch-lightning==1.4.2
+omegaconf==2.1.1
+test-tube>=0.7.5
+streamlit==1.12.1
+einops==0.3.0
+transformers==4.19.2
+webdataset==0.2.5
+kornia==0.6
+open_clip_torch==2.0.2
+invisible-watermark>=0.1.5
+streamlit-drawable-canvas==0.8.0
+torchmetrics==0.6.0
+xformers
+triton
+matplotlib
+wandb
+pillow