Nguyễn Bá Thiêm
Add image super resolution functionality
b16ab70
raw
history blame
51.1 kB
import numpy as np
import gdown
import gc
import os
import random
import time
import matplotlib.pyplot as plt
from PIL import Image
import torch
from torch import nn, optim
import torch.nn.functional as F
from torchvision import transforms
from basicsr.archs.arch_util import to_2tuple, trunc_normal_
from einops import rearrange
import math
class ChannelAttention(nn.Module):
    """Channel attention gate as used in RCAN.

    The input is globally average-pooled to one value per channel, passed
    through a 1x1-conv bottleneck ending in a sigmoid, and the resulting
    per-channel weights rescale the input.

    Args:
        num_feat (int): Channel number of intermediate features.
        squeeze_factor (int): Channel squeeze factor of the bottleneck. Default: 16.
    """

    def __init__(self, num_feat, squeeze_factor=16):
        super().__init__()
        squeezed = num_feat // squeeze_factor
        gate_layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(num_feat, squeezed, 1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, num_feat, 1, padding=0),
            nn.Sigmoid(),
        ]
        # Kept as a single Sequential named `attention` so parameter names
        # (attention.1.*, attention.3.*) stay the same.
        self.attention = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Return ``x`` scaled by its per-channel attention weights."""
        return x * self.attention(x)
class CAB(nn.Module):
    """Channel attention block: 3x3 conv -> GELU -> 3x3 conv -> ChannelAttention.

    Args:
        num_feat (int): Number of input and output channels.
        compress_ratio (int): The middle convolution works on
            ``num_feat // compress_ratio`` channels. Default: 3.
        squeeze_factor (int): Squeeze factor forwarded to ChannelAttention. Default: 30.
    """

    def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30):
        super().__init__()
        compressed = num_feat // compress_ratio
        stages = (
            nn.Conv2d(num_feat, compressed, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(compressed, num_feat, 3, 1, 1),
            ChannelAttention(num_feat, squeeze_factor),
        )
        self.cab = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the block to a (b, c, h, w) feature map."""
        return self.cab(x)
def window_partition(x, window_size):
    """Cut a channels-last feature map into non-overlapping square windows.

    Args:
        x: tensor of shape (b, h, w, c); h and w are expected to be
            multiples of ``window_size``.
        window_size (int): side length of each window.

    Returns:
        windows: (num_windows*b, window_size, window_size, c), ordered by
        batch, then window row, then window column.
    """
    batch, height, width, channels = x.shape
    rows, cols = height // window_size, width // window_size
    tiled = x.view(batch, rows, window_size, cols, window_size, channels)
    # Bring the two window-grid axes next to each other, ahead of the two
    # intra-window axes, so each window becomes one contiguous slab.
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiled.view(-1, window_size, window_size, channels)
def window_reverse(windows, window_size, h, w):
    """Stitch windows produced by ``window_partition`` back into a feature map.

    Args:
        windows: (num_windows*b, window_size, window_size, c)
        window_size (int): Window size
        h (int): Height of image (expected to be a multiple of window_size)
        w (int): Width of image (expected to be a multiple of window_size)

    Returns:
        x: (b, h, w, c)
    """
    rows, cols = h // window_size, w // window_size
    # Integer arithmetic only: the batch size is a count of windows, so there
    # is no reason to round-trip through float division and int() truncation.
    b = windows.shape[0] // (rows * cols)
    x = windows.view(b, rows, cols, window_size, window_size, -1)
    # Interleave grid and intra-window axes back into (row, y, col, x) order.
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
    return x
class WindowAttention(nn.Module):
    r"""Window based multi-head self attention (W-MSA) with relative position bias.

    Works for both shifted and non-shifted windows; the shifted case is
    handled through the optional additive ``mask`` passed to ``forward``.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        self.scale = qk_scale or (dim // num_heads) ** -0.5

        # One learnable bias per head for every possible relative offset
        # between two positions of a window: (2*Wh-1) * (2*Ww-1) offsets.
        num_offsets = (2 * window_size[0] - 1) * (2 * window_size[1] - 1)
        self.relative_position_bias_table = nn.Parameter(torch.zeros(num_offsets, num_heads))

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, rpi, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*b, n, c)
            rpi: relative position index used to look up the bias table;
                it is flattened and must hold (Wh*Ww) * (Wh*Ww) entries.
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        batch_windows, tokens, channels = x.shape
        heads = self.num_heads

        qkv = self.qkv(x).reshape(batch_windows, tokens, 3, heads, channels // heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        # Index instead of tuple-unpacking (the original notes this keeps torchscript happy).
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q * self.scale) @ k.transpose(-2, -1)

        area = self.window_size[0] * self.window_size[1]
        bias = self.relative_position_bias_table[rpi.view(-1)].view(area, area, -1)  # Wh*Ww, Wh*Ww, nH
        bias = bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + bias.unsqueeze(0)

        if mask is not None:
            # Expose the per-image window axis so the per-window mask
            # broadcasts over batch and heads, then fold it back.
            num_windows = mask.shape[0]
            attn = attn.view(batch_windows // num_windows, num_windows, heads, tokens, tokens)
            attn = attn + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, heads, tokens, tokens)

        attn = self.attn_drop(self.softmax(attn))

        out = (attn @ v).transpose(1, 2).reshape(batch_windows, tokens, channels)
        return self.proj_drop(self.proj(out))
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Outside training, or with a zero drop probability, ``x`` is returned
    unchanged. Otherwise each sample along dim 0 is either zeroed or scaled
    by ``1 / (1 - drop_prob)``.

    From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
    """
    if not training or drop_prob == 0.:
        return x
    survival = 1 - drop_prob
    # One random draw per sample, broadcast over all remaining dims
    # (works with tensors of any rank, not just 2D ConvNets).
    per_sample_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    keep_mask = torch.rand(per_sample_shape, dtype=x.dtype, device=x.device).add_(survival)
    keep_mask.floor_()  # binarize: 1 with probability `survival`, else 0
    return x.div(survival) * keep_mask
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        drop_prob (float | None): Probability of dropping a sample's path.
            ``None`` (the default) and ``0`` both disable dropping.

    From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # A missing/zero rate or eval mode is a no-op. Checking here also
        # keeps the default drop_prob=None from reaching the `1 - drop_prob`
        # arithmetic in drop_path, which would raise a TypeError in training.
        if not self.drop_prob or not self.training:
            return x
        return drop_path(x, self.drop_prob, self.training)
class Mlp(nn.Module):
    """Two-layer feed-forward network: fc1 -> activation -> dropout -> fc2 -> dropout.

    Args:
        in_features (int): Size of the last input dimension.
        hidden_features (int | None): Hidden width; falls back to ``in_features``.
        out_features (int | None): Output width; falls back to ``in_features``.
        act_layer: Activation class, instantiated without arguments. Default: nn.GELU
        drop (float): Dropout probability applied after both linear layers. Default: 0.0
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the MLP along the last dimension of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class OCAB(nn.Module):
    """Overlapping cross-attention block.

    Queries are taken from non-overlapping ``window_size`` windows, while keys
    and values come from larger overlapping windows extracted with
    ``nn.Unfold``. The attention output is projected and added to the input,
    followed by an MLP with its own residual connection.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Stored on the module; not read by this class.
        window_size (int): Side length of the query windows.
        overlap_ratio (float): The key/value window side is
            ``int(window_size * overlap_ratio) + window_size``.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to the qkv projection. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 2
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self,
                 dim,
                 input_resolution,
                 window_size,
                 overlap_ratio,
                 num_heads,
                 qkv_bias=True,
                 qk_scale=None,
                 mlp_ratio=2,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.window_size = window_size
        self.num_heads = num_heads
        self.scale = qk_scale or (dim // num_heads) ** -0.5
        self.overlap_win_size = int(window_size * overlap_ratio) + window_size
        overlap = self.overlap_win_size

        self.norm1 = norm_layer(dim)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # Stride equals the query window size, so there is one key/value
        # window per query window; padding centres the larger window.
        self.unfold = nn.Unfold(
            kernel_size=(overlap, overlap),
            stride=window_size,
            padding=(overlap - window_size) // 2)

        # Relative position bias table: one row per (query, key) offset.
        table_side = window_size + overlap - 1
        self.relative_position_bias_table = nn.Parameter(torch.zeros(table_side * table_side, num_heads))
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

        self.proj = nn.Linear(dim, dim)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=nn.GELU)

    def forward(self, x, x_size, rpi):
        """
        Args:
            x: tokens of shape (b, h*w, c).
            x_size: (h, w) of the underlying feature map.
            rpi: relative position index; flattened, it must hold
                window_size**2 * overlap_win_size**2 entries.

        Returns:
            Tensor of shape (b, h*w, c).
        """
        height, width = x_size
        batch, _, channels = x.shape
        win, overlap, heads = self.window_size, self.overlap_win_size, self.num_heads
        residual = x

        tokens = self.norm1(x).view(batch, height, width, channels)
        qkv = self.qkv(tokens).reshape(batch, height, width, 3, channels).permute(3, 0, 4, 1, 2)  # 3, b, c, h, w
        query_map = qkv[0].permute(0, 2, 3, 1)  # b, h, w, c
        key_value_map = torch.cat((qkv[1], qkv[2]), dim=1)  # b, 2*c, h, w

        # Queries: plain non-overlapping windows.
        q_windows = window_partition(query_map, win)  # nw*b, ws, ws, c
        q_windows = q_windows.view(-1, win * win, channels)  # nw*b, ws*ws, c

        # Keys/values: overlapping windows, split back into k and v.
        kv_windows = rearrange(
            self.unfold(key_value_map),  # b, 2*c*ow*ow, nw
            'b (nc ch owh oww) nw -> nc (b nw) (owh oww) ch',
            nc=2, ch=channels, owh=overlap, oww=overlap).contiguous()  # 2, nw*b, ow*ow, c
        k_windows, v_windows = kv_windows[0], kv_windows[1]

        num_windows, num_q, _ = q_windows.shape
        _, num_kv, _ = k_windows.shape
        head_dim = self.dim // heads
        q = q_windows.reshape(num_windows, num_q, heads, head_dim).permute(0, 2, 1, 3)  # nw*b, nH, nq, d
        k = k_windows.reshape(num_windows, num_kv, heads, head_dim).permute(0, 2, 1, 3)  # nw*b, nH, n, d
        v = v_windows.reshape(num_windows, num_kv, heads, head_dim).permute(0, 2, 1, 3)  # nw*b, nH, n, d

        attn = (q * self.scale) @ k.transpose(-2, -1)

        bias = self.relative_position_bias_table[rpi.view(-1)].view(
            win * win, overlap * overlap, -1)  # ws*ws, ow*ow, nH
        bias = bias.permute(2, 0, 1).contiguous()  # nH, ws*ws, ow*ow
        attn = self.softmax(attn + bias.unsqueeze(0))

        merged = (attn @ v).transpose(1, 2).reshape(num_windows, num_q, self.dim)

        # Merge windows back into a token sequence.
        merged = merged.view(-1, win, win, self.dim)
        x = window_reverse(merged, win, height, width)  # b h w c
        x = x.view(batch, height * width, self.dim)

        x = self.proj(x) + residual
        return x + self.mlp(self.norm2(x))
class HAB(nn.Module):
    r"""Hybrid Attention Block.

    Combines (shifted-)window self-attention with a parallel convolutional
    channel-attention branch (scaled by ``conv_scale``), followed by an MLP.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        compress_ratio (int): Channel compression ratio of the conv branch (CAB).
        squeeze_factor (int): Squeeze factor of the channel attention in CAB.
        conv_scale (float): Weight of the conv branch in the residual sum.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self,
                 dim,
                 input_resolution,
                 num_heads,
                 window_size=7,
                 shift_size=0,
                 compress_ratio=3,
                 squeeze_factor=30,
                 conv_scale=0.01,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio

        smallest_side = min(input_resolution)
        if smallest_side <= window_size:
            # The window would cover the whole input: use a single,
            # unshifted window of the input's smaller side instead.
            shift_size = 0
            window_size = smallest_side
        self.window_size = window_size
        self.shift_size = shift_size
        assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size'

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)

        self.conv_scale = conv_scale
        self.conv_block = CAB(num_feat=dim, compress_ratio=compress_ratio, squeeze_factor=squeeze_factor)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

    def forward(self, x, x_size, rpi_sa, attn_mask):
        """
        Args:
            x: tokens of shape (b, h*w, c).
            x_size: (h, w) of the underlying feature map.
            rpi_sa: relative position index forwarded to the window attention.
            attn_mask: mask for shifted windows; ignored when ``shift_size`` is 0.

        Returns:
            Tensor of shape (b, h*w, c).
        """
        height, width = x_size
        batch, _, channels = x.shape
        win, shift = self.window_size, self.shift_size
        use_shift = shift > 0
        residual = x

        feat = self.norm1(x).view(batch, height, width, channels)

        # Convolutional branch works on the unshifted map in (b, c, h, w) layout.
        conv_branch = self.conv_block(feat.permute(0, 3, 1, 2))
        conv_branch = conv_branch.permute(0, 2, 3, 1).contiguous().view(batch, height * width, channels)

        # Cyclic shift; the mask is only meaningful for shifted windows.
        if use_shift:
            feat = torch.roll(feat, shifts=(-shift, -shift), dims=(1, 2))
        mask = attn_mask if use_shift else None

        # Partition into windows, attend, and merge back.
        windows = window_partition(feat, win).view(-1, win * win, channels)  # nw*b, ws*ws, c
        attended = self.attn(windows, rpi=rpi_sa, mask=mask)
        attended = attended.view(-1, win, win, channels)
        merged = window_reverse(attended, win, height, width)  # b h w c

        # Undo the cyclic shift.
        if use_shift:
            merged = torch.roll(merged, shifts=(shift, shift), dims=(1, 2))
        attn_branch = merged.view(batch, height * width, channels)

        # Residual sum of both branches, then the FFN.
        x = residual + self.drop_path(attn_branch) + conv_branch * self.conv_scale
        return x + self.drop_path(self.mlp(self.norm2(x)))
class AttenBlocks(nn.Module):
    """A series of attention blocks for one RHAG.

    Builds ``depth`` HAB blocks (alternating unshifted and shifted windows)
    followed by a single OCAB, and an optional downsample layer.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        compress_ratio (int): Channel compression ratio of the HAB conv branch.
        squeeze_factor (int): Squeeze factor of the HAB channel attention.
        conv_scale (float): Weight of the HAB conv branch.
        overlap_ratio (float): Overlap ratio of the OCAB key/value windows.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate,
            either one value for all blocks or one value per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Stored on the module; not used by this class. Default: False.
    """

    def __init__(self,
                 dim,
                 input_resolution,
                 depth,
                 num_heads,
                 window_size,
                 compress_ratio,
                 squeeze_factor,
                 conv_scale,
                 overlap_ratio,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # Per-block rates may come as a list or a tuple (the documented type).
        # A tuple used to be handed to every HAB unchanged and failed there
        # on the `drop_path > 0.` comparison.
        per_block_rates = isinstance(drop_path, (list, tuple))

        # build blocks
        self.blocks = nn.ModuleList([
            HAB(
                dim=dim,
                input_resolution=input_resolution,
                num_heads=num_heads,
                window_size=window_size,
                # even blocks use regular windows, odd blocks shifted ones
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                compress_ratio=compress_ratio,
                squeeze_factor=squeeze_factor,
                conv_scale=conv_scale,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if per_block_rates else drop_path,
                norm_layer=norm_layer) for i in range(depth)
        ])

        # OCAB
        self.overlap_attn = OCAB(
            dim=dim,
            input_resolution=input_resolution,
            window_size=window_size,
            overlap_ratio=overlap_ratio,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer)

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, x_size, params):
        """Run all HAB blocks, then the OCAB, then the optional downsample.

        Args:
            x: tokens of shape (b, h*w, c).
            x_size: (h, w) of the underlying feature map.
            params (dict): must provide 'rpi_sa', 'attn_mask' and 'rpi_oca'.
        """
        for blk in self.blocks:
            x = blk(x, x_size, params['rpi_sa'], params['attn_mask'])
        x = self.overlap_attn(x, x_size, params['rpi_oca'])
        if self.downsample is not None:
            x = self.downsample(x)
        return x
class RHAG(nn.Module):
    """Residual Hybrid Attention Group (RHAG).

    Wraps an ``AttenBlocks`` group with a convolution (or identity) and a
    residual connection around the whole group.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        compress_ratio (int): Channel compression ratio of the HAB conv branch.
        squeeze_factor (int): Squeeze factor of the HAB channel attention.
        conv_scale (float): Weight of the HAB conv branch.
        overlap_ratio (float): Overlap ratio of the OCAB key/value windows.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        img_size: Input image size.
        patch_size: Patch size.
        resi_connection: The block before the residual connection:
            '1conv' (3x3 convolution) or 'identity'.

    Raises:
        ValueError: if ``resi_connection`` is neither '1conv' nor 'identity'.
    """

    def __init__(self,
                 dim,
                 input_resolution,
                 depth,
                 num_heads,
                 window_size,
                 compress_ratio,
                 squeeze_factor,
                 conv_scale,
                 overlap_ratio,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False,
                 img_size=224,
                 patch_size=4,
                 resi_connection='1conv'):
        super(RHAG, self).__init__()

        # Fail at construction: previously an unknown value left `self.conv`
        # undefined and only surfaced as an AttributeError inside forward().
        if resi_connection not in ('1conv', 'identity'):
            raise ValueError(
                f"resi_connection {resi_connection!r} is not supported. "
                "Supported values: '1conv' and 'identity'.")

        self.dim = dim
        self.input_resolution = input_resolution

        self.residual_group = AttenBlocks(
            dim=dim,
            input_resolution=input_resolution,
            depth=depth,
            num_heads=num_heads,
            window_size=window_size,
            compress_ratio=compress_ratio,
            squeeze_factor=squeeze_factor,
            conv_scale=conv_scale,
            overlap_ratio=overlap_ratio,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=drop,
            attn_drop=attn_drop,
            drop_path=drop_path,
            norm_layer=norm_layer,
            downsample=downsample,
            use_checkpoint=use_checkpoint)

        if resi_connection == '1conv':
            self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
        else:  # 'identity'
            self.conv = nn.Identity()

        # These two only reshape between token and image layouts; in_chans
        # is merely stored by them, so the 0 passed here has no effect.
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
        self.patch_unembed = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)

    def forward(self, x, x_size, params):
        """Apply the attention group and conv, then add the input back.

        Args:
            x: tokens of shape (b, h*w, c).
            x_size: (h, w) of the underlying feature map.
            params (dict): forwarded unchanged to the attention group.
        """
        body = self.residual_group(x, x_size, params)
        # tokens -> image layout for the conv, then back to tokens
        feature_map = self.conv(self.patch_unembed(body, x_size))
        return self.patch_embed(feature_map) + x
class PatchEmbed(nn.Module):
    r"""Image to Patch Embedding.

    In this file the module performs no projection: ``forward`` only
    flattens a (b, c, h, w) map into (b, h*w, c) tokens and optionally
    normalizes them. The size arguments are stored as attributes.

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels (stored only). Default: 3.
        embed_dim (int): Channel count handed to ``norm_layer``. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)

        grid_h = self.img_size[0] // self.patch_size[0]
        grid_w = self.img_size[1] // self.patch_size[1]
        self.patches_resolution = [grid_h, grid_w]
        self.num_patches = grid_h * grid_w

        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        """Flatten the spatial dims and move channels last: b c h w -> b h*w c."""
        tokens = x.flatten(2).transpose(1, 2)
        if self.norm is None:
            return tokens
        return self.norm(tokens)
class PatchUnEmbed(nn.Module):
    r"""Patch tokens back to an image-shaped feature map.

    ``forward`` reshapes (b, h*w, c) tokens into a (b, embed_dim, h, w) map.
    The size arguments are stored as attributes; ``norm_layer`` is accepted
    but not used.

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels (stored only). Default: 3.
        embed_dim (int): Channel count of the tokens. Default: 96.
        norm_layer (nn.Module, optional): Unused. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)

        grid_h = self.img_size[0] // self.patch_size[0]
        grid_w = self.img_size[1] // self.patch_size[1]
        self.patches_resolution = [grid_h, grid_w]
        self.num_patches = grid_h * grid_w

        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x, x_size):
        """Reshape tokens to an image layout: b h*w c -> b c h w."""
        batch = x.shape[0]
        height, width = x_size[0], x_size[1]
        channels_first = x.transpose(1, 2).contiguous()
        return channels_first.view(batch, self.embed_dim, height, width)
class Upsample(nn.Sequential):
    """Upsample module built from conv + PixelShuffle stages.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
            ``scale=1`` (2^0) yields an empty, pass-through sequence.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: if ``scale`` is not a positive power of two or 3.
    """

    def __init__(self, scale, num_feat):
        m = []
        # `scale > 0` keeps 0 out of the power-of-two branch: 0 passes the
        # bit trick but has no logarithm.
        if scale > 0 and (scale & (scale - 1)) == 0:  # scale = 2^n
            # Exact integer log2; avoids truncating a floating-point log.
            for _ in range(int(scale).bit_length() - 1):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)
class HAT(nn.Module):
    r"""Hybrid Attention Transformer.

    A PyTorch implementation of : `Activating More Pixels in Image Super-Resolution Transformer`.
    Some codes are based on SwinIR.

    Only ``upsampler='pixelshuffle'`` builds and runs a reconstruction head.
    With any other value (including the default ``''``) ``forward`` skips the
    network and only applies and undoes the mean/range normalization.

    Args:
        img_size (int | tuple(int)): Input image size. Default 64
        patch_size (int | tuple(int)): Patch size. Default: 1
        in_chans (int): Number of input image channels. Default: 3
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each residual hybrid attention group.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        compress_ratio (int): Channel compression ratio of the conv branch. Default: 3
        squeeze_factor (int): Squeeze factor of the channel attention. Default: 30
        conv_scale (float): Weight of the conv branch. Default: 0.01
        overlap_ratio (float): Overlap ratio of the cross-attention windows. Default: 0.5
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
        img_range: Image range. 1. or 255.
        upsampler: The reconstruction module. Only 'pixelshuffle' is implemented here.
        resi_connection: The convolutional block before residual connection. '1conv'/'identity'

    Raises:
        ValueError: if ``resi_connection`` is neither '1conv' nor 'identity'.
    """

    def __init__(self,
                 img_size=64,
                 patch_size=1,
                 in_chans=3,
                 embed_dim=96,
                 depths=(6, 6, 6, 6),
                 num_heads=(6, 6, 6, 6),
                 window_size=7,
                 compress_ratio=3,
                 squeeze_factor=30,
                 conv_scale=0.01,
                 overlap_ratio=0.5,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm,
                 ape=False,
                 patch_norm=True,
                 use_checkpoint=False,
                 upscale=2,
                 img_range=1.,
                 upsampler='',
                 resi_connection='1conv',
                 **kwargs):
        super(HAT, self).__init__()

        # Fail at construction: previously an unknown value left
        # `conv_after_body` undefined and only surfaced inside forward().
        if resi_connection not in ('1conv', 'identity'):
            raise ValueError(
                f"resi_connection {resi_connection!r} is not supported. "
                "Supported values: '1conv' and 'identity'.")

        self.window_size = window_size
        self.shift_size = window_size // 2
        self.overlap_ratio = overlap_ratio

        num_in_ch = in_chans
        num_out_ch = in_chans
        num_feat = 64
        self.img_range = img_range
        if in_chans == 3:
            rgb_mean = (0.4488, 0.4371, 0.4040)
            self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
        else:
            self.mean = torch.zeros(1, 1, 1, 1)
        self.upscale = upscale
        self.upsampler = upsampler

        # relative position index
        relative_position_index_SA = self.calculate_rpi_sa()
        relative_position_index_OCA = self.calculate_rpi_oca()
        self.register_buffer('relative_position_index_SA', relative_position_index_SA)
        self.register_buffer('relative_position_index_OCA', relative_position_index_OCA)

        # ------------------------- 1, shallow feature extraction ------------------------- #
        self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)

        # ------------------------- 2, deep feature extraction ------------------------- #
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = embed_dim
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=embed_dim,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # merge non-overlapping patches into image
        self.patch_unembed = PatchUnEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=embed_dim,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build Residual Hybrid Attention Groups (RHAG)
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = RHAG(
                dim=embed_dim,
                input_resolution=(patches_resolution[0], patches_resolution[1]),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                compress_ratio=compress_ratio,
                squeeze_factor=squeeze_factor,
                conv_scale=conv_scale,
                overlap_ratio=overlap_ratio,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results
                norm_layer=norm_layer,
                downsample=None,
                use_checkpoint=use_checkpoint,
                img_size=img_size,
                patch_size=patch_size,
                resi_connection=resi_connection)
            self.layers.append(layer)
        self.norm = norm_layer(self.num_features)

        # build the last conv layer in deep feature extraction
        if resi_connection == '1conv':
            self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        else:  # 'identity'
            self.conv_after_body = nn.Identity()

        # ------------------------- 3, high quality image reconstruction ------------------------- #
        if self.upsampler == 'pixelshuffle':
            # for classical SR
            self.conv_before_upsample = nn.Sequential(
                nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
            self.upsample = Upsample(upscale, num_feat)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear weights (truncated normal) and LayerNorm affine parameters."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def calculate_rpi_sa(self):
        """Return the (Wh*Ww, Wh*Ww) relative position index for window self-attention."""
        coords_h = torch.arange(self.window_size)
        coords_w = torch.arange(self.window_size)
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size - 1
        # row offset becomes the high "digit" of a flat table index
        relative_coords[:, :, 0] *= 2 * self.window_size - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        return relative_position_index

    def calculate_rpi_oca(self):
        """Return the (ws*ws, wse*wse) relative position index for overlapping cross-attention.

        NOTE(review): with the offset used below the result can contain
        negative values (e.g. window_size=2, overlap_ratio=0.5 yields -5..10),
        which index the bias table from its end. This is left untouched on
        purpose: changing it would change which table rows trained weights
        map to.
        """
        window_size_ori = self.window_size
        window_size_ext = self.window_size + int(self.overlap_ratio * self.window_size)

        coords_h = torch.arange(window_size_ori)
        coords_w = torch.arange(window_size_ori)
        coords_ori = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, ws, ws
        coords_ori_flatten = torch.flatten(coords_ori, 1)  # 2, ws*ws

        coords_h = torch.arange(window_size_ext)
        coords_w = torch.arange(window_size_ext)
        coords_ext = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, wse, wse
        coords_ext_flatten = torch.flatten(coords_ext, 1)  # 2, wse*wse

        relative_coords = coords_ext_flatten[:, None, :] - coords_ori_flatten[:, :, None]  # 2, ws*ws, wse*wse

        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # ws*ws, wse*wse, 2
        relative_coords[:, :, 0] += window_size_ori - window_size_ext + 1
        relative_coords[:, :, 1] += window_size_ori - window_size_ext + 1

        relative_coords[:, :, 0] *= window_size_ori + window_size_ext - 1
        relative_position_index = relative_coords.sum(-1)
        return relative_position_index

    def calculate_mask(self, x_size):
        """Build the additive attention mask for shifted-window attention.

        Args:
            x_size: (h, w) of the feature map.

        Returns:
            Tensor (num_windows, ws*ws, ws*ws) with 0 where two positions
            belong to the same region and -100 otherwise.
        """
        h, w = x_size
        img_mask = torch.zeros((1, h, w, 1))  # 1 h w 1
        h_slices = (slice(0, -self.window_size), slice(-self.window_size,
                                                       -self.shift_size), slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size), slice(-self.window_size,
                                                       -self.shift_size), slice(-self.shift_size, None))
        # Label each of the 3x3 regions with its own id. Distinct loop names
        # are used so the h/w sizes above are not shadowed.
        cnt = 0
        for h_slice in h_slices:
            for w_slice in w_slices:
                img_mask[:, h_slice, w_slice, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nw, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        return attn_mask

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        """Run the deep feature extractor on a (b, c, h, w) feature map."""
        x_size = (x.shape[2], x.shape[3])

        # Calculate attention mask and relative position index in advance to speed up inference.
        # The original code is very time-consuming for large window size.
        attn_mask = self.calculate_mask(x_size).to(x.device)
        params = {'attn_mask': attn_mask, 'rpi_sa': self.relative_position_index_SA, 'rpi_oca': self.relative_position_index_OCA}

        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x, x_size, params)

        x = self.norm(x)  # b seq_len c
        x = self.patch_unembed(x, x_size)

        return x

    def forward(self, x):
        """Super-resolve ``x``; the network is only applied for the 'pixelshuffle' upsampler."""
        # `mean` is a plain tensor attribute, so it is moved to x's dtype/device here.
        self.mean = self.mean.type_as(x)
        x = (x - self.mean) * self.img_range

        if self.upsampler == 'pixelshuffle':
            # for classical SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.conv_before_upsample(x)
            x = self.conv_last(self.upsample(x))

        x = x / self.img_range + self.mean

        return x
# ------------------------------ HYPERPARAMS ------------------------------ #
# Module-level settings; `config` is the default configuration of Network below.
config = {
    # Network.__init__ forwards every entry of this section except "type"
    # to the HAT constructor as a keyword argument.
    "network_g": {
        "type": "HAT",
        "upscale": 4,
        "in_chans": 3,
        "img_size": 64,
        "window_size": 16,
        "compress_ratio": 3,
        "squeeze_factor": 30,
        "conv_scale": 0.01,
        "overlap_ratio": 0.5,
        "img_range": 1.,
        "depths": [6, 6, 6, 6, 6, 6],
        "embed_dim": 180,
        "num_heads": [6, 6, 6, 6, 6, 6],
        "mlp_ratio": 2,
        "upsampler": 'pixelshuffle',
        "resi_connection": '1conv'
    },
    "train": {
        "ema_decay": 0.999,
        # lr, weight_decay and betas are read by Network.__init__ to build the Adam optimizer.
        "optim_g": {
            "type": "Adam",
            "lr": 1e-4,
            "weight_decay": 0,
            "betas": [0.9, 0.99]
        },
        # NOTE(review): the scheduler, iteration and loss settings below are not
        # read anywhere in the visible part of this file — confirm where they are used.
        "scheduler": {
            "type": "MultiStepLR",
            "milestones": [12, 20, 25, 30],
            "gamma": 0.5
        },
        "total_iter": 30,
        "warmup_iter": -1,
        "pixel_opt": {
            "type": "L1Loss",
            "loss_weight": 1.0,
            "reduction": "mean"
        }
    },
    # tile_size is the tile edge (in input pixels) used by Network.tile_valid.
    # NOTE(review): tile_pad is not read in the visible part of this file.
    'tile':{
        'tile_size': 56,
        'tile_pad': 4
    }
}
# Use CUDA when available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# DEVICE = torch.device('mps' if torch.backends.mps.is_built() else 'cpu')
# Runs at import time and reports the selected device.
print('device', DEVICE)
class Network:
def __init__(self, config=config, device=DEVICE):
    """Build the HAT model on ``device`` and an Adam optimizer for it.

    Args:
        config (dict): Settings with the layout of the module-level
            ``config``; the 'network_g' and 'train' -> 'optim_g' sections
            are read here.
        device: Device the model is moved to.
    """
    self.config = config
    self.device = device

    # Exactly these 'network_g' entries are handed to HAT ("type" is not).
    net_cfg = self.config['network_g']
    hat_arg_names = (
        'upscale', 'in_chans', 'img_size', 'window_size', 'compress_ratio',
        'squeeze_factor', 'conv_scale', 'overlap_ratio', 'img_range', 'depths',
        'embed_dim', 'num_heads', 'mlp_ratio', 'upsampler', 'resi_connection',
    )
    hat_kwargs = {name: net_cfg[name] for name in hat_arg_names}
    self.model = HAT(**hat_kwargs).to(self.device)

    optim_cfg = self.config['train']['optim_g']
    self.optimizer = optim.Adam(
        self.model.parameters(),
        lr=optim_cfg['lr'],
        weight_decay=optim_cfg['weight_decay'],
        betas=tuple(optim_cfg['betas']))
def load_network(self, checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(checkpoint['model'])
self.optimizer.load_state_dict(checkpoint['optimizer']) # before create and load scheduler
def pre_process(self):
# pad to multiplication of window_size
window_size = self.config['network_g']['window_size'] * 4
self.scale = self.config['network_g']['upscale']
self.mod_pad_h, self.mod_pad_w = 0, 0
_, _, h, w = self.input_tile.size()
if h % window_size != 0:
self.mod_pad_h = window_size - h % window_size
# Loop to add padding to the height until it's a multiple of window_size
for i in range(self.mod_pad_h):
self.input_tile = F.pad(self.input_tile, (0, 0, 0, 1), 'reflect')
if w % window_size != 0:
# Loop to add padding to the width until it's a multiple of window_size
self.mod_pad_w = window_size - w % window_size
for i in range(self.mod_pad_w):
self.input_tile = F.pad(self.input_tile, (0, 1, 0, 0), 'reflect')
def post_process(self):
_, _, h, w = self.output_tile.size()
self.output_tile = self.output_tile[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
def tile_valid(self, lr_images):
"""
Process all tiles of an image in a batch and then merge them back into the output image.
"""
batch, channel, height, width = lr_images.shape
output_height = height * self.config['network_g']['upscale']
output_width = width * self.config['network_g']['upscale']
output_shape = (batch, channel, output_height, output_width)
# Start with black image for output
sr_images = lr_images.new_zeros(output_shape)
tiles_x = math.ceil(width / self.config['tile']['tile_size'])
tiles_y = math.ceil(height / self.config['tile']['tile_size'])
tile_list = []
# Extract all tiles
for y in range(tiles_y):
for x in range(tiles_x):
input_start_x = x * self.config['tile']['tile_size']
input_end_x = min(input_start_x + self.config['tile']['tile_size'], width)
input_start_y = y * self.config['tile']['tile_size']
input_end_y = min(input_start_y + self.config['tile']['tile_size'], height)
input_start_x_pad = max(input_start_x - self.config['tile']['tile_pad'], 0)
input_end_x_pad = min(input_end_x + self.config['tile']['tile_pad'], width)
input_start_y_pad = max(input_start_y - self.config['tile']['tile_pad'], 0)
input_end_y_pad = min(input_end_y + self.config['tile']['tile_pad'], height)
# Extract tile and add to list
self.input_tile = lr_images[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
self.pre_process()
tile_list.append(self.input_tile.clone())
output_tiles = []
# Determine the number of tiles to process per batch
batch_size = 16 # Adjust based on your specific situation
for i in range(0, len(tile_list), batch_size):
# Extract a batch of tiles
batch = tile_list[i:i + batch_size]
tile_batch = torch.cat(batch, dim=0) # This creates a batch of tiles
# Process the batch through the model
self.model.eval()
with torch.no_grad():
# Ensure that each tile processed by the model returns a 3D tensor (C, H, W)
output_batch = self.model(tile_batch)
# Extend the list of processed tiles
output_tiles.append(output_batch) # Assuming output_batch is 4D
# Concatenate along the first dimension to combine all the processed tiles
output_tile_batch = torch.cat(output_tiles, dim=0) # This should be 4D now
for y in range(tiles_y):
for x in range(tiles_x):
# input tile area on total image
input_start_x = x * self.config['tile']['tile_size']
input_end_x = min(input_start_x + self.config['tile']['tile_size'], width)
input_start_y = y * self.config['tile']['tile_size']
input_end_y = min(input_start_y + self.config['tile']['tile_size'], height)
# input tile area on total image with padding
input_start_x_pad = max(input_start_x - self.config['tile']['tile_pad'], 0)
input_end_x_pad = min(input_end_x + self.config['tile']['tile_pad'], width)
input_start_y_pad = max(input_start_y - self.config['tile']['tile_pad'], 0)
input_end_y_pad = min(input_end_y + self.config['tile']['tile_pad'], height)
# input tile dimensions
input_tile_width = input_end_x - input_start_x
input_tile_height = input_end_y - input_start_y
tile_idx = y * tiles_x + x
self.pre_process()
self.output_tile = output_tile_batch[tile_idx, :, :, :].unsqueeze(0).clone()
self.post_process()
# output tile area on total image
output_start_x = input_start_x * self.config['network_g']['upscale']
output_end_x = input_end_x * self.config['network_g']['upscale']
output_start_y = input_start_y * self.config['network_g']['upscale']
output_end_y = input_end_y * self.config['network_g']['upscale']
# output tile area without padding
output_start_x_tile = (input_start_x - input_start_x_pad) * self.config['network_g']['upscale']
output_end_x_tile = output_start_x_tile + input_tile_width * self.config['network_g']['upscale']
output_start_y_tile = (input_start_y - input_start_y_pad) * self.config['network_g']['upscale']
output_end_y_tile = output_start_y_tile + input_tile_height * self.config['network_g']['upscale']
# put tile into output image
sr_images[:, :, output_start_y:output_end_y,
output_start_x:output_end_x] = self.output_tile[:, :, output_start_y_tile:output_end_y_tile,
output_start_x_tile:output_end_x_tile]
del self.input_tile, self.output_tile, tile_batch, tile_list, output_tile_batch, output_tiles
gc.collect()
torch.cuda.empty_cache()
return sr_images
def inference(self, lr_image, hr_image = None, deployment = False):
"""
- lr_image: torch.Tensor
3D Tensor (C, H, W)
- hr_image: torch.Tesnor
3D Tensor (C, H, W). This parameter is optional, for comparing the model output and the
ground-truth high-res image. If used solely for inference, skip this. Default is None/
"""
lr_image = lr_image.unsqueeze(0).to(self.device)
self.for_inference = True
with torch.no_grad():
sr_image = self.tile_valid(lr_image)
sr_image = torch.clamp(sr_image, 0, 1)
if deployment:
return sr_image.squeeze(0)
else:
lr_image = lr_image.squeeze(0)
sr_image = sr_image.squeeze(0)
print(">> Size of low-res image:", lr_image.size())
print(">> Size of super-res image:", sr_image.size())
if hr_image != None:
print(">> Size of high-res image:", hr_image.size())
if hr_image != None:
fig, axes = plt.subplots(1, 3, figsize=(10, 6))
axes[0].imshow(lr_image.cpu().detach().permute((1, 2, 0)))
axes[0].set_title('Low Resolution')
axes[1].imshow(sr_image.cpu().detach().permute((1, 2, 0)))
axes[1].set_title('Super Resolution')
axes[2].imshow(hr_image.cpu().detach().permute((1, 2, 0)))
axes[2].set_title('High Resolution')
for ax in axes.flat:
ax.axis('off')
else:
fig, axes = plt.subplots(1, 2, figsize=(10, 6))
axes[0].imshow(lr_image.cpu().detach().permute((1, 2, 0)))
axes[0].set_title('Low Resolution')
axes[1].imshow(sr_image.cpu().detach().permute((1, 2, 0)))
axes[1].set_title('Super Resolution')
for ax in axes.flat:
ax.axis('off')
plt.tight_layout()
plt.show()
return sr_image
def HAT_for_deployment(lr_image, model_path='models/HAT/hat_model_checkpoint_best.pth'):
    """Super-resolve one image with a freshly loaded HAT checkpoint.

    Args:
        lr_image: Low-resolution input accepted by
            ``transforms.functional.to_tensor`` (e.g. a PIL image).
        model_path (str): Path to the checkpoint passed to
            ``Network.load_network``.

    Returns:
        PIL.Image.Image: The super-resolved image as 8-bit data.
    """
    hat = Network()
    hat.load_network(model_path)

    # RGBA / grayscale / palette images would yield a channel count that
    # does not match a 3-channel network, so normalise them to RGB first.
    if (isinstance(lr_image, Image.Image)
            and hat.config['network_g']['in_chans'] == 3
            and lr_image.mode != 'RGB'):
        lr_image = lr_image.convert('RGB')
    lr_tensor = transforms.functional.to_tensor(lr_image)

    t1 = time.perf_counter()
    sr_image = hat.inference(lr_tensor, deployment=True).cpu().numpy()
    t2 = time.perf_counter()
    print("Time taken to infer:", t2 - t1)

    # [C, H, W] -> [H, W, C] as expected by PIL.
    sr_image = np.transpose(sr_image, (1, 2, 0))
    # inference() already clamps to [0, 1]; clip again defensively, then
    # round (instead of truncating) when quantising to 8 bits.
    sr_image = np.round(np.clip(sr_image, 0.0, 1.0) * 255.0).astype(np.uint8)
    return Image.fromarray(sr_image)
if __name__ == "__main__":
    import os
    import sys

    # Make the directory two levels above this file importable.
    script_dir = os.path.dirname(os.path.realpath(__file__))
    sys.path.append(os.path.join(script_dir, "../../"))

    # Checkpoint location and the Google Drive file id it can be fetched from.
    model_path = 'models/HAT/hat_model_checkpoint_best.pth'
    gdrive_id = '1LHIUM7YoUDk8cXWzVZhroAcA1xXi-d87'

    if os.path.exists(model_path):
        print(f"Model file found at {model_path}. No need to download.")
    else:
        print(f"Model file not found at {model_path}. Downloading from Google Drive...")
        # Create the target directory up front (per the original note, gdown
        # will not create directory paths automatically).
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        # NOTE: the download itself is currently disabled:
        # gdown.download(id=gdrive_id, output=model_path, quiet=False)

    # Demo run: upscale a sample image and print the resulting PIL image.
    # (For an interactive comparison plot, build a Network, call
    # load_network(model_path) and inference() on the to_tensor'd image.)
    image_path = "images/demo.png"
    lr_image = Image.open(image_path)
    print(HAT_for_deployment(lr_image, model_path))