|
import functools |
|
|
|
from tensorflow.keras import layers |
|
|
|
from .attentions import RDCAB |
|
from .misc_gating import ResidualSplitHeadMultiAxisGmlpLayer |
|
|
|
# Factory for a 1x1 Conv2D with "same" padding. The kernel size and padding
# are pre-bound here; every other Conv2D argument (e.g. `filters`, `use_bias`,
# `name`) is supplied as a keyword at call time.
Conv1x1 = functools.partial(
    layers.Conv2D,
    padding="same",
    kernel_size=(1, 1),
)
|
|
|
|
|
def BottleneckBlock(
    features: int,
    block_size,
    grid_size,
    num_groups: int = 1,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    channels_reduction: int = 4,
    dropout_rate: float = 0.0,
    use_bias: bool = True,
    name: str = "bottleneck_block",
):
    """Build a bottleneck of multi-axis gMLP layers and RDCABs with a long skip.

    This is a functional-style builder. Calling it only captures the
    configuration and returns a callable ``apply(x)``. The sub-layers are
    instantiated and wired up each time that callable is invoked.

    Args:
        features: Filter count of the 1x1 input projection. Also passed as
            ``num_channels`` to every RDCAB.
        block_size: Forwarded unchanged to each
            ``ResidualSplitHeadMultiAxisGmlpLayer``. Presumably a
            (height, width) pair; confirm against that layer.
        grid_size: Forwarded unchanged to each
            ``ResidualSplitHeadMultiAxisGmlpLayer``. Presumably a
            (height, width) pair; confirm against that layer.
        num_groups: Number of stacked (gMLP layer, RDCAB) pairs. With 0,
            the result is the projection added to itself.
        block_gmlp_factor: Forwarded to the gMLP layer.
        grid_gmlp_factor: Forwarded to the gMLP layer.
        input_proj_factor: Forwarded to the gMLP layer.
        channels_reduction: Passed as ``reduction`` to each RDCAB.
        dropout_rate: Forwarded to the gMLP layer.
        use_bias: Forwarded to the projection conv and to every sub-layer.
        name: Prefix for every sub-layer name. Presumably kept stable so
            pretrained weights can be matched by layer name.

    Returns:
        A callable mapping an input tensor to
        ``stacked_groups(projection) + projection``.
    """

    def apply(x):
        # 1x1 projection to `features` channels. This tensor feeds the
        # stacked groups and is also the long-skip source added at the end.
        projection = Conv1x1(
            filters=features,
            use_bias=use_bias,
            name=f"{name}_input_proj",
        )(x)

        hidden = projection
        for group_idx in range(num_groups):
            # Multi-axis gMLP layer for this group.
            gmlp_layer = ResidualSplitHeadMultiAxisGmlpLayer(
                grid_size=grid_size,
                block_size=block_size,
                grid_gmlp_factor=grid_gmlp_factor,
                block_gmlp_factor=block_gmlp_factor,
                input_proj_factor=input_proj_factor,
                use_bias=use_bias,
                dropout_rate=dropout_rate,
                name=f"{name}_SplitHeadMultiAxisGmlpLayer_{group_idx}",
            )
            hidden = gmlp_layer(hidden)

            # Channel-attention block (RDCAB) applied to the gMLP output.
            attention_block = RDCAB(
                num_channels=features,
                reduction=channels_reduction,
                use_bias=use_bias,
                name=f"{name}_channel_attention_block_1_{group_idx}",
            )
            hidden = attention_block(hidden)

        # Long skip connection around all groups. This assumes the stacked
        # layers preserve the projection's shape (TODO confirm in the
        # sub-layer implementations).
        return hidden + projection

    return apply
|
|