diff --git a/ControlNet b/ControlNet
new file mode 160000
index 0000000000000000000000000000000000000000..c1cb2393802d418937ea458cce4abc545b6f95d3
--- /dev/null
+++ b/ControlNet
@@ -0,0 +1 @@
+Subproject commit c1cb2393802d418937ea458cce4abc545b6f95d3
diff --git a/annotator/__pycache__/util.cpython-38.pyc b/annotator/__pycache__/util.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f5ad1057bf5db1914f2a5a2b8d1390d99e452982
Binary files /dev/null and b/annotator/__pycache__/util.cpython-38.pyc differ
diff --git a/annotator/canny/__init__.py b/annotator/canny/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb0da951dc838ec9dec2131007e036113281800b
--- /dev/null
+++ b/annotator/canny/__init__.py
@@ -0,0 +1,6 @@
+import cv2
+
+
+class CannyDetector:
+    def __call__(self, img, low_threshold, high_threshold):
+        return cv2.Canny(img, low_threshold, high_threshold)
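A minimal usage sketch for the CannyDetector added above. The image path and thresholds are placeholders; the class is a thin wrapper around cv2.Canny, so it expects a uint8 image and returns a single-channel edge map.

import cv2
from annotator.canny import CannyDetector

apply_canny = CannyDetector()
img = cv2.imread("input.png")           # placeholder path; uint8 HxWx3 image
edges = apply_canny(img, 100, 200)      # low/high hysteresis thresholds passed straight to cv2.Canny
cv2.imwrite("edges.png", edges)         # uint8 HxW map with edge pixels set to 255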
diff --git a/annotator/ckpts/ckpts.txt b/annotator/ckpts/ckpts.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1978551fb2a9226814eaf58459f414fcfac4e69b
--- /dev/null
+++ b/annotator/ckpts/ckpts.txt
@@ -0,0 +1 @@
+Weights here.
\ No newline at end of file
diff --git a/annotator/ckpts/network-bsds500.pth b/annotator/ckpts/network-bsds500.pth
new file mode 100644
index 0000000000000000000000000000000000000000..36cff8560c17530f48cbb9a43c6e9a0d6f704af3
--- /dev/null
+++ b/annotator/ckpts/network-bsds500.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:58a858782f5fa3e0ca3dc92e7a1a609add93987d77be3dfa54f8f8419d881a94
+size 58871680
diff --git a/annotator/hed/__init__.py b/annotator/hed/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..56532c374df5c26f9ec53e2ac0dd924f4534bbdd
--- /dev/null
+++ b/annotator/hed/__init__.py
@@ -0,0 +1,132 @@
+import numpy as np
+import cv2
+import os
+import torch
+from einops import rearrange
+from annotator.util import annotator_ckpts_path
+
+
+class Network(torch.nn.Module):
+    def __init__(self, model_path):
+        super().__init__()
+
+        self.netVggOne = torch.nn.Sequential(
+            torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False),
+            torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False)
+        )
+
+        self.netVggTwo = torch.nn.Sequential(
+            torch.nn.MaxPool2d(kernel_size=2, stride=2),
+            torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False),
+            torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False)
+        )
+
+        self.netVggThr = torch.nn.Sequential(
+            torch.nn.MaxPool2d(kernel_size=2, stride=2),
+            torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False),
+            torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False),
+            torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False)
+        )
+
+        self.netVggFou = torch.nn.Sequential(
+            torch.nn.MaxPool2d(kernel_size=2, stride=2),
+            torch.nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False),
+            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False),
+            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False)
+        )
+
+        self.netVggFiv = torch.nn.Sequential(
+            torch.nn.MaxPool2d(kernel_size=2, stride=2),
+            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False),
+            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False),
+            torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+            torch.nn.ReLU(inplace=False)
+        )
+
+        self.netScoreOne = torch.nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0)
+        self.netScoreTwo = torch.nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0)
+        self.netScoreThr = torch.nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0)
+        self.netScoreFou = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
+        self.netScoreFiv = torch.nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
+
+        self.netCombine = torch.nn.Sequential(
+            torch.nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
+            torch.nn.Sigmoid()
+        )
+
+        self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.load(model_path).items()})
+
+    def forward(self, tenInput):
+        tenInput = tenInput * 255.0
+        tenInput = tenInput - torch.tensor(data=[104.00698793, 116.66876762, 122.67891434], dtype=tenInput.dtype, device=tenInput.device).view(1, 3, 1, 1)
+
+        tenVggOne = self.netVggOne(tenInput)
+        tenVggTwo = self.netVggTwo(tenVggOne)
+        tenVggThr = self.netVggThr(tenVggTwo)
+        tenVggFou = self.netVggFou(tenVggThr)
+        tenVggFiv = self.netVggFiv(tenVggFou)
+
+        tenScoreOne = self.netScoreOne(tenVggOne)
+        tenScoreTwo = self.netScoreTwo(tenVggTwo)
+        tenScoreThr = self.netScoreThr(tenVggThr)
+        tenScoreFou = self.netScoreFou(tenVggFou)
+        tenScoreFiv = self.netScoreFiv(tenVggFiv)
+
+        tenScoreOne = torch.nn.functional.interpolate(input=tenScoreOne, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
+        tenScoreTwo = torch.nn.functional.interpolate(input=tenScoreTwo, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
+        tenScoreThr = torch.nn.functional.interpolate(input=tenScoreThr, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
+        tenScoreFou = torch.nn.functional.interpolate(input=tenScoreFou, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
+        tenScoreFiv = torch.nn.functional.interpolate(input=tenScoreFiv, size=(tenInput.shape[2], tenInput.shape[3]), mode='bilinear', align_corners=False)
+
+        return self.netCombine(torch.cat([ tenScoreOne, tenScoreTwo, tenScoreThr, tenScoreFou, tenScoreFiv ], 1))
+
+
+class HEDdetector:
+    def __init__(self):
+        remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/network-bsds500.pth"
+        modelpath = os.path.join(annotator_ckpts_path, "network-bsds500.pth")
+        if not os.path.exists(modelpath):
+            from basicsr.utils.download_util import load_file_from_url
+            load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
+        self.netNetwork = Network(modelpath).cuda().eval()
+
+    def __call__(self, input_image):
+        assert input_image.ndim == 3
+        input_image = input_image[:, :, ::-1].copy()
+        with torch.no_grad():
+            image_hed = torch.from_numpy(input_image).float().cuda()
+            image_hed = image_hed / 255.0
+            image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
+            edge = self.netNetwork(image_hed)[0]
+            edge = (edge.cpu().numpy() * 255.0).clip(0, 255).astype(np.uint8)
+            return edge[0]
+
+
+def nms(x, t, s):
+    x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
+
+    f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
+    f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
+    f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
+    f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
+
+    y = np.zeros_like(x)
+
+    for f in [f1, f2, f3, f4]:
+        np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
+
+    z = np.zeros_like(y, dtype=np.uint8)
+    z[y > t] = 255
+    return z
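A hedged usage sketch for the HED detector and the nms helper above. The image path, threshold, and sigma are placeholders, and a CUDA device is assumed because HEDdetector moves both the network and the input to .cuda(); the checkpoint is fetched from the Hugging Face URL above on first use.

import cv2
from annotator.hed import HEDdetector, nms

apply_hed = HEDdetector()               # downloads network-bsds500.pth into annotator/ckpts if missing
img = cv2.imread("input.png")           # placeholder path; uint8 HxWx3 image
edge = apply_hed(img)                   # soft HED edge map, uint8 HxW
thin = nms(edge, 127, 3.0)              # blur with sigma 3.0, keep directional maxima, binarize at 127
cv2.imwrite("hed_thin.png", thin)       # 0/255 scribble-like map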
diff --git a/annotator/hed/__pycache__/__init__.cpython-38.pyc b/annotator/hed/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b25032b8b52e5c2550be572024b86d7e565e0068
Binary files /dev/null and b/annotator/hed/__pycache__/__init__.cpython-38.pyc differ
diff --git a/annotator/midas/__init__.py b/annotator/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc5ac03eea6f5ba7968706f1863c8bc4f8aaaf6a
--- /dev/null
+++ b/annotator/midas/__init__.py
@@ -0,0 +1,38 @@
+import cv2
+import numpy as np
+import torch
+
+from einops import rearrange
+from .api import MiDaSInference
+
+
+class MidasDetector:
+    def __init__(self):
+        self.model = MiDaSInference(model_type="dpt_hybrid").cuda()
+
+    def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1):
+        assert input_image.ndim == 3
+        image_depth = input_image
+        with torch.no_grad():
+            image_depth = torch.from_numpy(image_depth).float().cuda()
+            image_depth = image_depth / 127.5 - 1.0
+            image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
+            depth = self.model(image_depth)[0]
+
+            depth_pt = depth.clone()
+            depth_pt -= torch.min(depth_pt)
+            depth_pt /= torch.max(depth_pt)
+            depth_pt = depth_pt.cpu().numpy()
+            depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
+
+            depth_np = depth.cpu().numpy()
+            x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3)
+            y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3)
+            z = np.ones_like(x) * a
+            x[depth_pt < bg_th] = 0
+            y[depth_pt < bg_th] = 0
+            normal = np.stack([x, y, z], axis=2)
+            normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5
+            normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
+
+            return depth_image, normal_image
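A hedged usage sketch for MidasDetector. It returns a normalized depth map plus a pseudo normal map built from Sobel gradients of the depth, with bg_th masking near-zero-depth regions out of the normals. A CUDA device is assumed, the dpt_hybrid checkpoint is downloaded on first use, and the input is assumed to have been resized so both sides are multiples of 64 (ControlNet's resize_image utility, not shown in this diff, does that).

import cv2
from annotator.midas import MidasDetector

apply_midas = MidasDetector()                 # loads the dpt_hybrid MiDaS model onto the GPU
img = cv2.imread("input.png")                 # placeholder path; uint8 HxWx3, sides multiples of 64
depth_image, normal_image = apply_midas(img)  # both uint8; a=2*pi and bg_th=0.1 are the defaults above
cv2.imwrite("depth.png", depth_image)         # relative depth scaled to [0, 255]
cv2.imwrite("normal.png", normal_image)       # unit normals packed into three channels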
diff --git a/annotator/midas/api.py b/annotator/midas/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ab9f15bf96bbaffcee0e3e29fc9d3979d6c32e8
--- /dev/null
+++ b/annotator/midas/api.py
@@ -0,0 +1,169 @@
+# based on https://github.com/isl-org/MiDaS
+
+import cv2
+import os
+import torch
+import torch.nn as nn
+from torchvision.transforms import Compose
+
+from .midas.dpt_depth import DPTDepthModel
+from .midas.midas_net import MidasNet
+from .midas.midas_net_custom import MidasNet_small
+from .midas.transforms import Resize, NormalizeImage, PrepareForNet
+from annotator.util import annotator_ckpts_path
+
+
+ISL_PATHS = {
+    "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"),
+    "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"),
+    "midas_v21": "",
+    "midas_v21_small": "",
+}
+
+remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
+
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+def load_midas_transform(model_type):
+    # https://github.com/isl-org/MiDaS/blob/master/run.py
+    # load transform only
+    if model_type == "dpt_large":  # DPT-Large
+        net_w, net_h = 384, 384
+        resize_mode = "minimal"
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+    elif model_type == "dpt_hybrid":  # DPT-Hybrid
+        net_w, net_h = 384, 384
+        resize_mode = "minimal"
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+    elif model_type == "midas_v21":
+        net_w, net_h = 384, 384
+        resize_mode = "upper_bound"
+        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+    elif model_type == "midas_v21_small":
+        net_w, net_h = 256, 256
+        resize_mode = "upper_bound"
+        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+    else:
+        assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
+
+    transform = Compose(
+        [
+            Resize(
+                net_w,
+                net_h,
+                resize_target=None,
+                keep_aspect_ratio=True,
+                ensure_multiple_of=32,
+                resize_method=resize_mode,
+                image_interpolation_method=cv2.INTER_CUBIC,
+            ),
+            normalization,
+            PrepareForNet(),
+        ]
+    )
+
+    return transform
+
+
+def load_model(model_type):
+    # https://github.com/isl-org/MiDaS/blob/master/run.py
+    # load network
+    model_path = ISL_PATHS[model_type]
+    if model_type == "dpt_large":  # DPT-Large
+        model = DPTDepthModel(
+            path=model_path,
+            backbone="vitl16_384",
+            non_negative=True,
+        )
+        net_w, net_h = 384, 384
+        resize_mode = "minimal"
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+    elif model_type == "dpt_hybrid":  # DPT-Hybrid
+        if not os.path.exists(model_path):
+            from basicsr.utils.download_util import load_file_from_url
+            load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
+
+        model = DPTDepthModel(
+            path=model_path,
+            backbone="vitb_rn50_384",
+            non_negative=True,
+        )
+        net_w, net_h = 384, 384
+        resize_mode = "minimal"
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+    elif model_type == "midas_v21":
+        model = MidasNet(model_path, non_negative=True)
+        net_w, net_h = 384, 384
+        resize_mode = "upper_bound"
+        normalization = NormalizeImage(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+        )
+
+    elif model_type == "midas_v21_small":
+        model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
+                               non_negative=True, blocks={'expand': True})
+        net_w, net_h = 256, 256
+        resize_mode = "upper_bound"
+        normalization = NormalizeImage(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+        )
+
+    else:
+        print(f"model_type '{model_type}' not implemented, use: --model_type large")
+        assert False
+
+    transform = Compose(
+        [
+            Resize(
+                net_w,
+                net_h,
+                resize_target=None,
+                keep_aspect_ratio=True,
+                ensure_multiple_of=32,
+                resize_method=resize_mode,
+                image_interpolation_method=cv2.INTER_CUBIC,
+            ),
+            normalization,
+            PrepareForNet(),
+        ]
+    )
+
+    return model.eval(), transform
+
+
+class MiDaSInference(nn.Module):
+    MODEL_TYPES_TORCH_HUB = [
+        "DPT_Large",
+        "DPT_Hybrid",
+        "MiDaS_small"
+    ]
+    MODEL_TYPES_ISL = [
+        "dpt_large",
+        "dpt_hybrid",
+        "midas_v21",
+        "midas_v21_small",
+    ]
+
+    def __init__(self, model_type):
+        super().__init__()
+        assert (model_type in self.MODEL_TYPES_ISL)
+        model, _ = load_model(model_type)
+        self.model = model
+        self.model.train = disabled_train
+
+    def forward(self, x):
+        with torch.no_grad():
+            prediction = self.model(x)
+        return prediction
+
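For use outside the MidasDetector wrapper, a sketch of the standalone API above: load_midas_transform builds the Resize / NormalizeImage / PrepareForNet pipeline and MiDaSInference wraps the inference-only network. The image path is a placeholder and a CUDA device is assumed.

import cv2
import torch
from annotator.midas.api import MiDaSInference, load_midas_transform

model = MiDaSInference(model_type="dpt_hybrid").cuda()
transform = load_midas_transform("dpt_hybrid")             # Resize -> NormalizeImage -> PrepareForNet

img = cv2.cvtColor(cv2.imread("input.png"), cv2.COLOR_BGR2RGB) / 255.0
sample = transform({"image": img})                         # dict in, dict out; "image" becomes CHW float32
x = torch.from_numpy(sample["image"]).unsqueeze(0).cuda()  # add the batch dimension
with torch.no_grad():
    depth = model(x)                                       # (1, H', W') relative inverse depth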
diff --git a/annotator/midas/midas/__init__.py b/annotator/midas/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/annotator/midas/midas/base_model.py b/annotator/midas/midas/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd
--- /dev/null
+++ b/annotator/midas/midas/base_model.py
@@ -0,0 +1,16 @@
+import torch
+
+
+class BaseModel(torch.nn.Module):
+    def load(self, path):
+        """Load model from file.
+
+        Args:
+            path (str): file path
+        """
+        parameters = torch.load(path, map_location=torch.device('cpu'))
+
+        if "optimizer" in parameters:
+            parameters = parameters["model"]
+
+        self.load_state_dict(parameters)
diff --git a/annotator/midas/midas/blocks.py b/annotator/midas/midas/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78
--- /dev/null
+++ b/annotator/midas/midas/blocks.py
@@ -0,0 +1,342 @@
+import torch
+import torch.nn as nn
+
+from .vit import (
+    _make_pretrained_vitb_rn50_384,
+    _make_pretrained_vitl16_384,
+    _make_pretrained_vitb16_384,
+    forward_vit,
+)
+
+def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
+    if backbone == "vitl16_384":
+        pretrained = _make_pretrained_vitl16_384(
+            use_pretrained, hooks=hooks, use_readout=use_readout
+        )
+        scratch = _make_scratch(
+            [256, 512, 1024, 1024], features, groups=groups, expand=expand
+        )  # ViT-L/16 - 85.0% Top1 (backbone)
+    elif backbone == "vitb_rn50_384":
+        pretrained = _make_pretrained_vitb_rn50_384(
+            use_pretrained,
+            hooks=hooks,
+            use_vit_only=use_vit_only,
+            use_readout=use_readout,
+        )
+        scratch = _make_scratch(
+            [256, 512, 768, 768], features, groups=groups, expand=expand
+        )  # ViT-Hybrid (ResNet-50 + ViT-B/16) backbone
+    elif backbone == "vitb16_384":
+        pretrained = _make_pretrained_vitb16_384(
+            use_pretrained, hooks=hooks, use_readout=use_readout
+        )
+        scratch = _make_scratch(
+            [96, 192, 384, 768], features, groups=groups, expand=expand
+        )  # ViT-B/16 - 84.6% Top1 (backbone)
+    elif backbone == "resnext101_wsl":
+        pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
+        scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand)  # resnext101_wsl
+    elif backbone == "efficientnet_lite3":
+        pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
+        scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand)  # efficientnet_lite3     
+    else:
+        print(f"Backbone '{backbone}' not implemented")
+        assert False
+        
+    return pretrained, scratch
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+    scratch = nn.Module()
+
+    out_shape1 = out_shape
+    out_shape2 = out_shape
+    out_shape3 = out_shape
+    out_shape4 = out_shape
+    if expand==True:
+        out_shape1 = out_shape
+        out_shape2 = out_shape*2
+        out_shape3 = out_shape*4
+        out_shape4 = out_shape*8
+
+    scratch.layer1_rn = nn.Conv2d(
+        in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+    scratch.layer2_rn = nn.Conv2d(
+        in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+    scratch.layer3_rn = nn.Conv2d(
+        in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+    scratch.layer4_rn = nn.Conv2d(
+        in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+
+    return scratch
+
+
+def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
+    efficientnet = torch.hub.load(
+        "rwightman/gen-efficientnet-pytorch",
+        "tf_efficientnet_lite3",
+        pretrained=use_pretrained,
+        exportable=exportable
+    )
+    return _make_efficientnet_backbone(efficientnet)
+
+
+def _make_efficientnet_backbone(effnet):
+    pretrained = nn.Module()
+
+    pretrained.layer1 = nn.Sequential(
+        effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
+    )
+    pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
+    pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
+    pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
+
+    return pretrained
+    
+
+def _make_resnet_backbone(resnet):
+    pretrained = nn.Module()
+    pretrained.layer1 = nn.Sequential(
+        resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
+    )
+
+    pretrained.layer2 = resnet.layer2
+    pretrained.layer3 = resnet.layer3
+    pretrained.layer4 = resnet.layer4
+
+    return pretrained
+
+
+def _make_pretrained_resnext101_wsl(use_pretrained):
+    resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
+    return _make_resnet_backbone(resnet)
+
+
+
+class Interpolate(nn.Module):
+    """Interpolation module.
+    """
+
+    def __init__(self, scale_factor, mode, align_corners=False):
+        """Init.
+
+        Args:
+            scale_factor (float): scaling
+            mode (str): interpolation mode
+        """
+        super(Interpolate, self).__init__()
+
+        self.interp = nn.functional.interpolate
+        self.scale_factor = scale_factor
+        self.mode = mode
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input
+
+        Returns:
+            tensor: interpolated data
+        """
+
+        x = self.interp(
+            x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
+        )
+
+        return x
+
+
+class ResidualConvUnit(nn.Module):
+    """Residual convolution module.
+    """
+
+    def __init__(self, features):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super().__init__()
+
+        self.conv1 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True
+        )
+
+        self.conv2 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True
+        )
+
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input
+
+        Returns:
+            tensor: output
+        """
+        out = self.relu(x)
+        out = self.conv1(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+
+        return out + x
+
+
+class FeatureFusionBlock(nn.Module):
+    """Feature fusion block.
+    """
+
+    def __init__(self, features):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super(FeatureFusionBlock, self).__init__()
+
+        self.resConfUnit1 = ResidualConvUnit(features)
+        self.resConfUnit2 = ResidualConvUnit(features)
+
+    def forward(self, *xs):
+        """Forward pass.
+
+        Returns:
+            tensor: output
+        """
+        output = xs[0]
+
+        if len(xs) == 2:
+            output += self.resConfUnit1(xs[1])
+
+        output = self.resConfUnit2(output)
+
+        output = nn.functional.interpolate(
+            output, scale_factor=2, mode="bilinear", align_corners=True
+        )
+
+        return output
+
+
+
+
+class ResidualConvUnit_custom(nn.Module):
+    """Residual convolution module.
+    """
+
+    def __init__(self, features, activation, bn):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super().__init__()
+
+        self.bn = bn
+
+        self.groups=1
+
+        self.conv1 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+        )
+        
+        self.conv2 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+        )
+
+        if self.bn==True:
+            self.bn1 = nn.BatchNorm2d(features)
+            self.bn2 = nn.BatchNorm2d(features)
+
+        self.activation = activation
+
+        self.skip_add = nn.quantized.FloatFunctional()
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input
+
+        Returns:
+            tensor: output
+        """
+        
+        out = self.activation(x)
+        out = self.conv1(out)
+        if self.bn==True:
+            out = self.bn1(out)
+       
+        out = self.activation(out)
+        out = self.conv2(out)
+        if self.bn==True:
+            out = self.bn2(out)
+
+        if self.groups > 1:
+            out = self.conv_merge(out)
+
+        return self.skip_add.add(out, x)
+
+        # return out + x
+
+
+class FeatureFusionBlock_custom(nn.Module):
+    """Feature fusion block.
+    """
+
+    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super(FeatureFusionBlock_custom, self).__init__()
+
+        self.deconv = deconv
+        self.align_corners = align_corners
+
+        self.groups=1
+
+        self.expand = expand
+        out_features = features
+        if self.expand==True:
+            out_features = features//2
+        
+        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
+        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
+        
+        self.skip_add = nn.quantized.FloatFunctional()
+
+    def forward(self, *xs):
+        """Forward pass.
+
+        Returns:
+            tensor: output
+        """
+        output = xs[0]
+
+        if len(xs) == 2:
+            res = self.resConfUnit1(xs[1])
+            output = self.skip_add.add(output, res)
+            # output += res
+
+        output = self.resConfUnit2(output)
+
+        output = nn.functional.interpolate(
+            output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
+        )
+
+        output = self.out_conv(output)
+
+        return output
+
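A small shape sketch for the decoder pieces above, using random tensors: _make_scratch projects encoder features of different widths to a common channel count, and FeatureFusionBlock_custom merges a coarse path with a skip connection and upsamples by 2. The channel counts below mirror the vitb_rn50_384 configuration but are otherwise arbitrary.

import torch
import torch.nn as nn
from annotator.midas.midas.blocks import _make_scratch, FeatureFusionBlock_custom

scratch = _make_scratch([256, 512, 768, 768], 256)   # 3x3 convs mapping each stage to 256 channels
fuse4 = FeatureFusionBlock_custom(256, nn.ReLU(False))
fuse3 = FeatureFusionBlock_custom(256, nn.ReLU(False))

layer_3 = torch.randn(1, 768, 24, 24)                # placeholder encoder activations
layer_4 = torch.randn(1, 768, 12, 12)

l3 = scratch.layer3_rn(layer_3)                      # (1, 256, 24, 24)
l4 = scratch.layer4_rn(layer_4)                      # (1, 256, 12, 12)

p4 = fuse4(l4)                                       # refined and upsampled to (1, 256, 24, 24)
p3 = fuse3(p4, l3)                                   # skip added, then upsampled to (1, 256, 48, 48)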
diff --git a/annotator/midas/midas/dpt_depth.py b/annotator/midas/midas/dpt_depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1
--- /dev/null
+++ b/annotator/midas/midas/dpt_depth.py
@@ -0,0 +1,109 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .base_model import BaseModel
+from .blocks import (
+    FeatureFusionBlock,
+    FeatureFusionBlock_custom,
+    Interpolate,
+    _make_encoder,
+    forward_vit,
+)
+
+
+def _make_fusion_block(features, use_bn):
+    return FeatureFusionBlock_custom(
+        features,
+        nn.ReLU(False),
+        deconv=False,
+        bn=use_bn,
+        expand=False,
+        align_corners=True,
+    )
+
+
+class DPT(BaseModel):
+    def __init__(
+        self,
+        head,
+        features=256,
+        backbone="vitb_rn50_384",
+        readout="project",
+        channels_last=False,
+        use_bn=False,
+    ):
+
+        super(DPT, self).__init__()
+
+        self.channels_last = channels_last
+
+        hooks = {
+            "vitb_rn50_384": [0, 1, 8, 11],
+            "vitb16_384": [2, 5, 8, 11],
+            "vitl16_384": [5, 11, 17, 23],
+        }
+
+        # Instantiate backbone and reassemble blocks
+        self.pretrained, self.scratch = _make_encoder(
+            backbone,
+            features,
+            False, # Set to true if you want to train from scratch, uses ImageNet weights
+            groups=1,
+            expand=False,
+            exportable=False,
+            hooks=hooks[backbone],
+            use_readout=readout,
+        )
+
+        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+        self.scratch.output_conv = head
+
+
+    def forward(self, x):
+        if self.channels_last == True:
+            x.contiguous(memory_format=torch.channels_last)
+
+        layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
+
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+        path_4 = self.scratch.refinenet4(layer_4_rn)
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+        out = self.scratch.output_conv(path_1)
+
+        return out
+
+
+class DPTDepthModel(DPT):
+    def __init__(self, path=None, non_negative=True, **kwargs):
+        features = kwargs["features"] if "features" in kwargs else 256
+
+        head = nn.Sequential(
+            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(True),
+            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+            nn.ReLU(True) if non_negative else nn.Identity(),
+            nn.Identity(),
+        )
+
+        super().__init__(head, **kwargs)
+
+        if path is not None:
+           self.load(path)
+
+    def forward(self, x):
+        return super().forward(x).squeeze(dim=1)
+
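A construction sketch for DPTDepthModel with no MiDaS checkpoint: DPT passes use_pretrained=False to _make_encoder, so timm builds the vitb_rn50_384 backbone with random weights and nothing is downloaded. This assumes a timm release compatible with the vendored vit.py (ControlNet pins an older timm); 384x384 is the training resolution, and any input with sides divisible by 32 should work via forward_flex.

import torch
from annotator.midas.midas.dpt_depth import DPTDepthModel

model = DPTDepthModel(path=None, backbone="vitb_rn50_384", non_negative=True).eval()
x = torch.randn(1, 3, 384, 384)        # placeholder input; sides must be divisible by 32
with torch.no_grad():
    depth = model(x)                   # (1, 384, 384): channel dim squeezed, values clamped non-negative
print(depth.shape)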
diff --git a/annotator/midas/midas/midas_net.py b/annotator/midas/midas/midas_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f
--- /dev/null
+++ b/annotator/midas/midas/midas_net.py
@@ -0,0 +1,76 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
+
+
+class MidasNet(BaseModel):
+    """Network for monocular depth estimation.
+    """
+
+    def __init__(self, path=None, features=256, non_negative=True):
+        """Init.
+
+        Args:
+            path (str, optional): Path to saved model. Defaults to None.
+            features (int, optional): Number of features. Defaults to 256.
+            non_negative (bool, optional): If True, clamp predictions to non-negative values. Defaults to True.
+        """
+        print("Loading weights: ", path)
+
+        super(MidasNet, self).__init__()
+
+        use_pretrained = False if path is None else True
+
+        self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
+
+        self.scratch.refinenet4 = FeatureFusionBlock(features)
+        self.scratch.refinenet3 = FeatureFusionBlock(features)
+        self.scratch.refinenet2 = FeatureFusionBlock(features)
+        self.scratch.refinenet1 = FeatureFusionBlock(features)
+
+        self.scratch.output_conv = nn.Sequential(
+            nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
+            Interpolate(scale_factor=2, mode="bilinear"),
+            nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(True),
+            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+            nn.ReLU(True) if non_negative else nn.Identity(),
+        )
+
+        if path:
+            self.load(path)
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input data (image)
+
+        Returns:
+            tensor: depth
+        """
+
+        layer_1 = self.pretrained.layer1(x)
+        layer_2 = self.pretrained.layer2(layer_1)
+        layer_3 = self.pretrained.layer3(layer_2)
+        layer_4 = self.pretrained.layer4(layer_3)
+
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+        path_4 = self.scratch.refinenet4(layer_4_rn)
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+        out = self.scratch.output_conv(path_1)
+
+        return torch.squeeze(out, dim=1)
diff --git a/annotator/midas/midas/midas_net_custom.py b/annotator/midas/midas/midas_net_custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c
--- /dev/null
+++ b/annotator/midas/midas/midas_net_custom.py
@@ -0,0 +1,128 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
+
+
+class MidasNet_small(BaseModel):
+    """Network for monocular depth estimation.
+    """
+
+    def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
+        blocks={'expand': True}):
+        """Init.
+
+        Args:
+            path (str, optional): Path to saved model. Defaults to None.
+            features (int, optional): Number of features. Defaults to 64.
+            backbone (str, optional): Backbone network for encoder. Defaults to efficientnet_lite3
+        """
+        print("Loading weights: ", path)
+
+        super(MidasNet_small, self).__init__()
+
+        use_pretrained = False if path else True
+                
+        self.channels_last = channels_last
+        self.blocks = blocks
+        self.backbone = backbone
+
+        self.groups = 1
+
+        features1=features
+        features2=features
+        features3=features
+        features4=features
+        self.expand = False
+        if "expand" in self.blocks and self.blocks['expand'] == True:
+            self.expand = True
+            features1=features
+            features2=features*2
+            features3=features*4
+            features4=features*8
+
+        self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
+  
+        self.scratch.activation = nn.ReLU(False)    
+
+        self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
+
+        
+        self.scratch.output_conv = nn.Sequential(
+            nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
+            Interpolate(scale_factor=2, mode="bilinear"),
+            nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
+            self.scratch.activation,
+            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+            nn.ReLU(True) if non_negative else nn.Identity(),
+            nn.Identity(),
+        )
+        
+        if path:
+            self.load(path)
+
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input data (image)
+
+        Returns:
+            tensor: depth
+        """
+        if self.channels_last==True:
+            print("self.channels_last = ", self.channels_last)
+            x.contiguous(memory_format=torch.channels_last)
+
+
+        layer_1 = self.pretrained.layer1(x)
+        layer_2 = self.pretrained.layer2(layer_1)
+        layer_3 = self.pretrained.layer3(layer_2)
+        layer_4 = self.pretrained.layer4(layer_3)
+        
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+
+        path_4 = self.scratch.refinenet4(layer_4_rn)
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+        
+        out = self.scratch.output_conv(path_1)
+
+        return torch.squeeze(out, dim=1)
+
+
+
+def fuse_model(m):
+    prev_previous_type = nn.Identity()
+    prev_previous_name = ''
+    previous_type = nn.Identity()
+    previous_name = ''
+    for name, module in m.named_modules():
+        if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
+            # print("FUSED ", prev_previous_name, previous_name, name)
+            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
+        elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
+            # print("FUSED ", prev_previous_name, previous_name)
+            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
+        # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
+        #    print("FUSED ", previous_name, name)
+        #    torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
+
+        prev_previous_type = previous_type
+        prev_previous_name = previous_name
+        previous_type = type(module)
+        previous_name = name
\ No newline at end of file
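A toy sketch of what fuse_model does, using a small Sequential instead of MidasNet_small (which would pull its EfficientNet backbone from torch.hub). In eval mode, each Conv2d + BatchNorm2d (+ ReLU) run is folded in place via torch.quantization.fuse_modules for faster inference.

import torch.nn as nn
from annotator.midas.midas.midas_net_custom import fuse_model

toy = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
    nn.Conv2d(8, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
).eval()                                   # fusion requires eval mode

fuse_model(toy)                            # folds both Conv+BN+ReLU triples in place
print(toy)                                 # fused ConvReLU2d modules remain; BN/ReLU slots become Identity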
diff --git a/annotator/midas/midas/transforms.py b/annotator/midas/midas/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9
--- /dev/null
+++ b/annotator/midas/midas/transforms.py
@@ -0,0 +1,234 @@
+import numpy as np
+import cv2
+import math
+
+
+def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
+    """Rezise the sample to ensure the given size. Keeps aspect ratio.
+
+    Args:
+        sample (dict): sample
+        size (tuple): image size
+
+    Returns:
+        tuple: new size
+    """
+    shape = list(sample["disparity"].shape)
+
+    if shape[0] >= size[0] and shape[1] >= size[1]:
+        return sample
+
+    scale = [0, 0]
+    scale[0] = size[0] / shape[0]
+    scale[1] = size[1] / shape[1]
+
+    scale = max(scale)
+
+    shape[0] = math.ceil(scale * shape[0])
+    shape[1] = math.ceil(scale * shape[1])
+
+    # resize
+    sample["image"] = cv2.resize(
+        sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
+    )
+
+    sample["disparity"] = cv2.resize(
+        sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
+    )
+    sample["mask"] = cv2.resize(
+        sample["mask"].astype(np.float32),
+        tuple(shape[::-1]),
+        interpolation=cv2.INTER_NEAREST,
+    )
+    sample["mask"] = sample["mask"].astype(bool)
+
+    return tuple(shape)
+
+
+class Resize(object):
+    """Resize sample to given size (width, height).
+    """
+
+    def __init__(
+        self,
+        width,
+        height,
+        resize_target=True,
+        keep_aspect_ratio=False,
+        ensure_multiple_of=1,
+        resize_method="lower_bound",
+        image_interpolation_method=cv2.INTER_AREA,
+    ):
+        """Init.
+
+        Args:
+            width (int): desired output width
+            height (int): desired output height
+            resize_target (bool, optional):
+                True: Resize the full sample (image, mask, target).
+                False: Resize image only.
+                Defaults to True.
+            keep_aspect_ratio (bool, optional):
+                True: Keep the aspect ratio of the input sample.
+                Output sample might not have the given width and height, and
+                resize behaviour depends on the parameter 'resize_method'.
+                Defaults to False.
+            ensure_multiple_of (int, optional):
+                Output width and height is constrained to be multiple of this parameter.
+                Defaults to 1.
+            resize_method (str, optional):
+                "lower_bound": Output will be at least as large as the given size.
+                "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
+                "minimal": Scale as least as possible.  (Output size might be smaller than given size.)
+                Defaults to "lower_bound".
+        """
+        self.__width = width
+        self.__height = height
+
+        self.__resize_target = resize_target
+        self.__keep_aspect_ratio = keep_aspect_ratio
+        self.__multiple_of = ensure_multiple_of
+        self.__resize_method = resize_method
+        self.__image_interpolation_method = image_interpolation_method
+
+    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+        if max_val is not None and y > max_val:
+            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+        if y < min_val:
+            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+        return y
+
+    def get_size(self, width, height):
+        # determine new height and width
+        scale_height = self.__height / height
+        scale_width = self.__width / width
+
+        if self.__keep_aspect_ratio:
+            if self.__resize_method == "lower_bound":
+                # scale such that output size is lower bound
+                if scale_width > scale_height:
+                    # fit width
+                    scale_height = scale_width
+                else:
+                    # fit height
+                    scale_width = scale_height
+            elif self.__resize_method == "upper_bound":
+                # scale such that output size is upper bound
+                if scale_width < scale_height:
+                    # fit width
+                    scale_height = scale_width
+                else:
+                    # fit height
+                    scale_width = scale_height
+            elif self.__resize_method == "minimal":
+                # scale as little as possible
+                if abs(1 - scale_width) < abs(1 - scale_height):
+                    # fit width
+                    scale_height = scale_width
+                else:
+                    # fit height
+                    scale_width = scale_height
+            else:
+                raise ValueError(
+                    f"resize_method {self.__resize_method} not implemented"
+                )
+
+        if self.__resize_method == "lower_bound":
+            new_height = self.constrain_to_multiple_of(
+                scale_height * height, min_val=self.__height
+            )
+            new_width = self.constrain_to_multiple_of(
+                scale_width * width, min_val=self.__width
+            )
+        elif self.__resize_method == "upper_bound":
+            new_height = self.constrain_to_multiple_of(
+                scale_height * height, max_val=self.__height
+            )
+            new_width = self.constrain_to_multiple_of(
+                scale_width * width, max_val=self.__width
+            )
+        elif self.__resize_method == "minimal":
+            new_height = self.constrain_to_multiple_of(scale_height * height)
+            new_width = self.constrain_to_multiple_of(scale_width * width)
+        else:
+            raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+        return (new_width, new_height)
+
+    def __call__(self, sample):
+        width, height = self.get_size(
+            sample["image"].shape[1], sample["image"].shape[0]
+        )
+
+        # resize sample
+        sample["image"] = cv2.resize(
+            sample["image"],
+            (width, height),
+            interpolation=self.__image_interpolation_method,
+        )
+
+        if self.__resize_target:
+            if "disparity" in sample:
+                sample["disparity"] = cv2.resize(
+                    sample["disparity"],
+                    (width, height),
+                    interpolation=cv2.INTER_NEAREST,
+                )
+
+            if "depth" in sample:
+                sample["depth"] = cv2.resize(
+                    sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
+                )
+
+            sample["mask"] = cv2.resize(
+                sample["mask"].astype(np.float32),
+                (width, height),
+                interpolation=cv2.INTER_NEAREST,
+            )
+            sample["mask"] = sample["mask"].astype(bool)
+
+        return sample
+
+
+class NormalizeImage(object):
+    """Normlize image by given mean and std.
+    """
+
+    def __init__(self, mean, std):
+        self.__mean = mean
+        self.__std = std
+
+    def __call__(self, sample):
+        sample["image"] = (sample["image"] - self.__mean) / self.__std
+
+        return sample
+
+
+class PrepareForNet(object):
+    """Prepare sample for usage as network input.
+    """
+
+    def __init__(self):
+        pass
+
+    def __call__(self, sample):
+        image = np.transpose(sample["image"], (2, 0, 1))
+        sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+
+        if "mask" in sample:
+            sample["mask"] = sample["mask"].astype(np.float32)
+            sample["mask"] = np.ascontiguousarray(sample["mask"])
+
+        if "disparity" in sample:
+            disparity = sample["disparity"].astype(np.float32)
+            sample["disparity"] = np.ascontiguousarray(disparity)
+
+        if "depth" in sample:
+            depth = sample["depth"].astype(np.float32)
+            sample["depth"] = np.ascontiguousarray(depth)
+
+        return sample
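A hedged composition sketch for the transforms above, mirroring the midas_v21 settings from api.py (upper_bound resizing with ImageNet statistics); the 480x640 dummy image is a placeholder. Resize snaps both sides to multiples of 32 while keeping the aspect ratio, NormalizeImage standardizes the channels, and PrepareForNet converts to CHW float32.

import cv2
import numpy as np
from torchvision.transforms import Compose
from annotator.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet

transform = Compose([
    Resize(384, 384, resize_target=None, keep_aspect_ratio=True,
           ensure_multiple_of=32, resize_method="upper_bound",
           image_interpolation_method=cv2.INTER_CUBIC),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])

sample = {"image": np.random.rand(480, 640, 3).astype(np.float32)}   # HxWx3 in [0, 1]
out = transform(sample)
print(out["image"].shape)   # (3, 288, 384): no side exceeds 384 and both are multiples of 32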
diff --git a/annotator/midas/midas/vit.py b/annotator/midas/midas/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74
--- /dev/null
+++ b/annotator/midas/midas/vit.py
@@ -0,0 +1,491 @@
+import torch
+import torch.nn as nn
+import timm
+import types
+import math
+import torch.nn.functional as F
+
+
+class Slice(nn.Module):
+    def __init__(self, start_index=1):
+        super(Slice, self).__init__()
+        self.start_index = start_index
+
+    def forward(self, x):
+        return x[:, self.start_index :]
+
+
+class AddReadout(nn.Module):
+    def __init__(self, start_index=1):
+        super(AddReadout, self).__init__()
+        self.start_index = start_index
+
+    def forward(self, x):
+        if self.start_index == 2:
+            readout = (x[:, 0] + x[:, 1]) / 2
+        else:
+            readout = x[:, 0]
+        return x[:, self.start_index :] + readout.unsqueeze(1)
+
+
+class ProjectReadout(nn.Module):
+    def __init__(self, in_features, start_index=1):
+        super(ProjectReadout, self).__init__()
+        self.start_index = start_index
+
+        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
+
+    def forward(self, x):
+        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
+        features = torch.cat((x[:, self.start_index :], readout), -1)
+
+        return self.project(features)
+
+
+class Transpose(nn.Module):
+    def __init__(self, dim0, dim1):
+        super(Transpose, self).__init__()
+        self.dim0 = dim0
+        self.dim1 = dim1
+
+    def forward(self, x):
+        x = x.transpose(self.dim0, self.dim1)
+        return x
+
+
+def forward_vit(pretrained, x):
+    b, c, h, w = x.shape
+
+    glob = pretrained.model.forward_flex(x)
+
+    layer_1 = pretrained.activations["1"]
+    layer_2 = pretrained.activations["2"]
+    layer_3 = pretrained.activations["3"]
+    layer_4 = pretrained.activations["4"]
+
+    layer_1 = pretrained.act_postprocess1[0:2](layer_1)
+    layer_2 = pretrained.act_postprocess2[0:2](layer_2)
+    layer_3 = pretrained.act_postprocess3[0:2](layer_3)
+    layer_4 = pretrained.act_postprocess4[0:2](layer_4)
+
+    unflatten = nn.Sequential(
+        nn.Unflatten(
+            2,
+            torch.Size(
+                [
+                    h // pretrained.model.patch_size[1],
+                    w // pretrained.model.patch_size[0],
+                ]
+            ),
+        )
+    )
+
+    if layer_1.ndim == 3:
+        layer_1 = unflatten(layer_1)
+    if layer_2.ndim == 3:
+        layer_2 = unflatten(layer_2)
+    if layer_3.ndim == 3:
+        layer_3 = unflatten(layer_3)
+    if layer_4.ndim == 3:
+        layer_4 = unflatten(layer_4)
+
+    layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
+    layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
+    layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
+    layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
+
+    return layer_1, layer_2, layer_3, layer_4
+
+
+def _resize_pos_embed(self, posemb, gs_h, gs_w):
+    posemb_tok, posemb_grid = (
+        posemb[:, : self.start_index],
+        posemb[0, self.start_index :],
+    )
+
+    gs_old = int(math.sqrt(len(posemb_grid)))
+
+    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
+    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
+
+    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+
+    return posemb
+
+
+def forward_flex(self, x):
+    b, c, h, w = x.shape
+
+    pos_embed = self._resize_pos_embed(
+        self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
+    )
+
+    B = x.shape[0]
+
+    if hasattr(self.patch_embed, "backbone"):
+        x = self.patch_embed.backbone(x)
+        if isinstance(x, (list, tuple)):
+            x = x[-1]  # last feature if backbone outputs list/tuple of features
+
+    x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
+
+    if getattr(self, "dist_token", None) is not None:
+        cls_tokens = self.cls_token.expand(
+            B, -1, -1
+        )  # stole cls_tokens impl from Phil Wang, thanks
+        dist_token = self.dist_token.expand(B, -1, -1)
+        x = torch.cat((cls_tokens, dist_token, x), dim=1)
+    else:
+        cls_tokens = self.cls_token.expand(
+            B, -1, -1
+        )  # stole cls_tokens impl from Phil Wang, thanks
+        x = torch.cat((cls_tokens, x), dim=1)
+
+    x = x + pos_embed
+    x = self.pos_drop(x)
+
+    for blk in self.blocks:
+        x = blk(x)
+
+    x = self.norm(x)
+
+    return x
+
+
+activations = {}
+
+
+def get_activation(name):
+    def hook(model, input, output):
+        activations[name] = output
+
+    return hook
+
+
+def get_readout_oper(vit_features, features, use_readout, start_index=1):
+    if use_readout == "ignore":
+        readout_oper = [Slice(start_index)] * len(features)
+    elif use_readout == "add":
+        readout_oper = [AddReadout(start_index)] * len(features)
+    elif use_readout == "project":
+        readout_oper = [
+            ProjectReadout(vit_features, start_index) for out_feat in features
+        ]
+    else:
+        assert (
+            False
+        ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
+
+    return readout_oper
+
+
+def _make_vit_b16_backbone(
+    model,
+    features=[96, 192, 384, 768],
+    size=[384, 384],
+    hooks=[2, 5, 8, 11],
+    vit_features=768,
+    use_readout="ignore",
+    start_index=1,
+):
+    pretrained = nn.Module()
+
+    pretrained.model = model
+    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+    pretrained.activations = activations
+
+    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+    # 32, 48, 136, 384
+    pretrained.act_postprocess1 = nn.Sequential(
+        readout_oper[0],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[0],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+        nn.ConvTranspose2d(
+            in_channels=features[0],
+            out_channels=features[0],
+            kernel_size=4,
+            stride=4,
+            padding=0,
+            bias=True,
+            dilation=1,
+            groups=1,
+        ),
+    )
+
+    pretrained.act_postprocess2 = nn.Sequential(
+        readout_oper[1],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[1],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+        nn.ConvTranspose2d(
+            in_channels=features[1],
+            out_channels=features[1],
+            kernel_size=2,
+            stride=2,
+            padding=0,
+            bias=True,
+            dilation=1,
+            groups=1,
+        ),
+    )
+
+    pretrained.act_postprocess3 = nn.Sequential(
+        readout_oper[2],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[2],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+    )
+
+    pretrained.act_postprocess4 = nn.Sequential(
+        readout_oper[3],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[3],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+        nn.Conv2d(
+            in_channels=features[3],
+            out_channels=features[3],
+            kernel_size=3,
+            stride=2,
+            padding=1,
+        ),
+    )
+
+    pretrained.model.start_index = start_index
+    pretrained.model.patch_size = [16, 16]
+
+    # We inject this function into the VisionTransformer instances so that
+    # we can use it with interpolated position embeddings without modifying the library source.
+    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+    pretrained.model._resize_pos_embed = types.MethodType(
+        _resize_pos_embed, pretrained.model
+    )
+
+    return pretrained
+
+
+def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
+    model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
+
+    hooks = [5, 11, 17, 23] if hooks == None else hooks
+    return _make_vit_b16_backbone(
+        model,
+        features=[256, 512, 1024, 1024],
+        hooks=hooks,
+        vit_features=1024,
+        use_readout=use_readout,
+    )
+
+
+def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
+    model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
+
+    hooks = [2, 5, 8, 11] if hooks is None else hooks
+    return _make_vit_b16_backbone(
+        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+    )
+
+
+def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
+    model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
+
+    hooks = [2, 5, 8, 11] if hooks is None else hooks
+    return _make_vit_b16_backbone(
+        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+    )
+
+
+def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
+    model = timm.create_model(
+        "vit_deit_base_distilled_patch16_384", pretrained=pretrained
+    )
+
+    hooks = [2, 5, 8, 11] if hooks is None else hooks
+    return _make_vit_b16_backbone(
+        model,
+        features=[96, 192, 384, 768],
+        hooks=hooks,
+        use_readout=use_readout,
+        start_index=2,
+    )
+
+
+def _make_vit_b_rn50_backbone(
+    model,
+    features=[256, 512, 768, 768],
+    size=[384, 384],
+    hooks=[0, 1, 8, 11],
+    vit_features=768,
+    use_vit_only=False,
+    use_readout="ignore",
+    start_index=1,
+):
+    pretrained = nn.Module()
+
+    pretrained.model = model
+
+    if use_vit_only:
+        pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+        pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+    else:
+        pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
+            get_activation("1")
+        )
+        pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
+            get_activation("2")
+        )
+
+    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+    pretrained.activations = activations
+
+    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+    if use_vit_only:
+        pretrained.act_postprocess1 = nn.Sequential(
+            readout_oper[0],
+            Transpose(1, 2),
+            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+            nn.Conv2d(
+                in_channels=vit_features,
+                out_channels=features[0],
+                kernel_size=1,
+                stride=1,
+                padding=0,
+            ),
+            nn.ConvTranspose2d(
+                in_channels=features[0],
+                out_channels=features[0],
+                kernel_size=4,
+                stride=4,
+                padding=0,
+                bias=True,
+                dilation=1,
+                groups=1,
+            ),
+        )
+
+        pretrained.act_postprocess2 = nn.Sequential(
+            readout_oper[1],
+            Transpose(1, 2),
+            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+            nn.Conv2d(
+                in_channels=vit_features,
+                out_channels=features[1],
+                kernel_size=1,
+                stride=1,
+                padding=0,
+            ),
+            nn.ConvTranspose2d(
+                in_channels=features[1],
+                out_channels=features[1],
+                kernel_size=2,
+                stride=2,
+                padding=0,
+                bias=True,
+                dilation=1,
+                groups=1,
+            ),
+        )
+    else:
+        pretrained.act_postprocess1 = nn.Sequential(
+            nn.Identity(), nn.Identity(), nn.Identity()
+        )
+        pretrained.act_postprocess2 = nn.Sequential(
+            nn.Identity(), nn.Identity(), nn.Identity()
+        )
+
+    pretrained.act_postprocess3 = nn.Sequential(
+        readout_oper[2],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[2],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+    )
+
+    pretrained.act_postprocess4 = nn.Sequential(
+        readout_oper[3],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[3],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+        nn.Conv2d(
+            in_channels=features[3],
+            out_channels=features[3],
+            kernel_size=3,
+            stride=2,
+            padding=1,
+        ),
+    )
+
+    pretrained.model.start_index = start_index
+    pretrained.model.patch_size = [16, 16]
+
+    # We inject this function into the VisionTransformer instances so that
+    # we can use it with interpolated position embeddings without modifying the library source.
+    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+
+    # We inject this function into the VisionTransformer instances so that
+    # we can use it with interpolated position embeddings without modifying the library source.
+    pretrained.model._resize_pos_embed = types.MethodType(
+        _resize_pos_embed, pretrained.model
+    )
+
+    return pretrained
+
+
+def _make_pretrained_vitb_rn50_384(
+    pretrained, use_readout="ignore", hooks=None, use_vit_only=False
+):
+    model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
+
+    hooks = [0, 1, 8, 11] if hooks is None else hooks
+    return _make_vit_b_rn50_backbone(
+        model,
+        features=[256, 512, 768, 768],
+        size=[384, 384],
+        hooks=hooks,
+        use_vit_only=use_vit_only,
+        use_readout=use_readout,
+    )
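+
+
+# Editorial sketch (not part of the upstream MiDaS code): the constructors above return a
+# bare nn.Module wrapper whose registered forward hooks stash the selected transformer
+# block outputs in `pretrained.activations` under the keys "1".."4", and whose
+# act_postprocess* heads turn those token sequences back into spatial feature maps.
+# Assuming a 384x384 input and the default "ignore" readout, usage could look like:
+#
+#     backbone = _make_pretrained_vitb16_384(pretrained=False)
+#     x = torch.randn(1, 3, 384, 384)
+#     _ = backbone.model.forward_flex(x)                             # fills backbone.activations
+#     feat1 = backbone.act_postprocess1(backbone.activations["1"])   # -> [1, 96, 96, 96]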
diff --git a/annotator/midas/utils.py b/annotator/midas/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59
--- /dev/null
+++ b/annotator/midas/utils.py
@@ -0,0 +1,189 @@
+"""Utils for monoDepth."""
+import sys
+import re
+import numpy as np
+import cv2
+import torch
+
+
+def read_pfm(path):
+    """Read pfm file.
+
+    Args:
+        path (str): path to file
+
+    Returns:
+        tuple: (data, scale)
+    """
+    with open(path, "rb") as file:
+
+        color = None
+        width = None
+        height = None
+        scale = None
+        endian = None
+
+        header = file.readline().rstrip()
+        if header.decode("ascii") == "PF":
+            color = True
+        elif header.decode("ascii") == "Pf":
+            color = False
+        else:
+            raise Exception("Not a PFM file: " + path)
+
+        dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
+        if dim_match:
+            width, height = list(map(int, dim_match.groups()))
+        else:
+            raise Exception("Malformed PFM header.")
+
+        scale = float(file.readline().decode("ascii").rstrip())
+        if scale < 0:
+            # little-endian
+            endian = "<"
+            scale = -scale
+        else:
+            # big-endian
+            endian = ">"
+
+        data = np.fromfile(file, endian + "f")
+        shape = (height, width, 3) if color else (height, width)
+
+        data = np.reshape(data, shape)
+        data = np.flipud(data)
+
+        return data, scale
+
+
+def write_pfm(path, image, scale=1):
+    """Write pfm file.
+
+    Args:
+        path (str): path to file
+        image (array): data
+        scale (int, optional): Scale. Defaults to 1.
+    """
+
+    with open(path, "wb") as file:
+        color = None
+
+        if image.dtype.name != "float32":
+            raise Exception("Image dtype must be float32.")
+
+        image = np.flipud(image)
+
+        if len(image.shape) == 3 and image.shape[2] == 3:  # color image
+            color = True
+        elif (
+            len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
+        ):  # greyscale
+            color = False
+        else:
+            raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
+
+        file.write("PF\n" if color else "Pf\n".encode())
+        file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
+
+        endian = image.dtype.byteorder
+
+        if endian == "<" or endian == "=" and sys.byteorder == "little":
+            scale = -scale
+
+        file.write("%f\n".encode() % scale)
+
+        image.tofile(file)
+
+
+def read_image(path):
+    """Read image and output RGB image (0-1).
+
+    Args:
+        path (str): path to file
+
+    Returns:
+        array: RGB image (0-1)
+    """
+    img = cv2.imread(path)
+
+    if img.ndim == 2:
+        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
+
+    return img
+
+
+def resize_image(img):
+    """Resize image and make it fit for network.
+
+    Args:
+        img (array): image
+
+    Returns:
+        tensor: data ready for network
+    """
+    height_orig = img.shape[0]
+    width_orig = img.shape[1]
+
+    if width_orig > height_orig:
+        scale = width_orig / 384
+    else:
+        scale = height_orig / 384
+
+    height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
+    width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
+
+    img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
+
+    img_resized = (
+        torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
+    )
+    img_resized = img_resized.unsqueeze(0)
+
+    return img_resized
+
+
+def resize_depth(depth, width, height):
+    """Resize depth map and bring to CPU (numpy).
+
+    Args:
+        depth (tensor): depth
+        width (int): image width
+        height (int): image height
+
+    Returns:
+        array: processed depth
+    """
+    depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
+
+    depth_resized = cv2.resize(
+        depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
+    )
+
+    return depth_resized
+
+def write_depth(path, depth, bits=1):
+    """Write depth map to pfm and png file.
+
+    Args:
+        path (str): filepath without extension
+        depth (array): depth
+        bits (int, optional): bytes per sample in the output png (1 for uint8, 2 for uint16). Defaults to 1.
+    """
+    write_pfm(path + ".pfm", depth.astype(np.float32))
+
+    depth_min = depth.min()
+    depth_max = depth.max()
+
+    max_val = (2**(8*bits))-1
+
+    if depth_max - depth_min > np.finfo("float").eps:
+        out = max_val * (depth - depth_min) / (depth_max - depth_min)
+    else:
+        out = np.zeros(depth.shape, dtype=depth.dtype)
+
+    if bits == 1:
+        cv2.imwrite(path + ".png", out.astype("uint8"))
+    elif bits == 2:
+        cv2.imwrite(path + ".png", out.astype("uint16"))
+
+    return
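+
+
+# Editorial note (not in the upstream utils): write_depth linearly rescales the depth map
+# onto the full range of the chosen integer type before saving the .png, e.g. with bits=2
+# the span [depth_min, depth_max] maps onto [0, 65535]. A hypothetical call:
+#
+#     depth = np.random.rand(480, 640).astype(np.float32)   # placeholder depth map
+#     write_depth("/tmp/depth_example", depth, bits=2)       # writes a .pfm and a 16-bit .png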
diff --git a/annotator/mlsd/__init__.py b/annotator/mlsd/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..42af28c682e781b30f691f65a475b53c9f3adc8b
--- /dev/null
+++ b/annotator/mlsd/__init__.py
@@ -0,0 +1,39 @@
+import cv2
+import numpy as np
+import torch
+import os
+
+from einops import rearrange
+from .models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny
+from .models.mbv2_mlsd_large import MobileV2_MLSD_Large
+from .utils import pred_lines
+
+from annotator.util import annotator_ckpts_path
+
+
+remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/mlsd_large_512_fp32.pth"
+
+
+class MLSDdetector:
+    def __init__(self):
+        model_path = os.path.join(annotator_ckpts_path, "mlsd_large_512_fp32.pth")
+        if not os.path.exists(model_path):
+            from basicsr.utils.download_util import load_file_from_url
+            load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
+        model = MobileV2_MLSD_Large()
+        model.load_state_dict(torch.load(model_path), strict=True)
+        self.model = model.cuda().eval()
+
+    def __call__(self, input_image, thr_v, thr_d):
+        assert input_image.ndim == 3
+        img = input_image
+        img_output = np.zeros_like(img)
+        try:
+            with torch.no_grad():
+                lines = pred_lines(img, self.model, [img.shape[0], img.shape[1]], thr_v, thr_d)
+                for line in lines:
+                    x_start, y_start, x_end, y_end = [int(val) for val in line]
+                    cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1)
+        except Exception:
+            pass
+        return img_output[:, :, 0]
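+
+
+# Editorial sketch (assumes a CUDA device and that the MLSD checkpoint can be downloaded):
+# the detector returns a single-channel uint8 line map with the same height and width as
+# the input HWC image, with detected segments drawn in white.
+#
+#     apply_mlsd = MLSDdetector()
+#     line_map = apply_mlsd(input_image, thr_v=0.1, thr_d=0.1)   # shape (H, W)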
diff --git a/annotator/mlsd/models/mbv2_mlsd_large.py b/annotator/mlsd/models/mbv2_mlsd_large.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b9799e7573ca41549b3c3b13ac47b906b369603
--- /dev/null
+++ b/annotator/mlsd/models/mbv2_mlsd_large.py
@@ -0,0 +1,292 @@
+import os
+import sys
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+from  torch.nn import  functional as F
+
+
+class BlockTypeA(nn.Module):
+    def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
+        super(BlockTypeA, self).__init__()
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(in_c2, out_c2, kernel_size=1),
+            nn.BatchNorm2d(out_c2),
+            nn.ReLU(inplace=True)
+        )
+        self.conv2 = nn.Sequential(
+            nn.Conv2d(in_c1, out_c1, kernel_size=1),
+            nn.BatchNorm2d(out_c1),
+            nn.ReLU(inplace=True)
+        )
+        self.upscale = upscale
+
+    def forward(self, a, b):
+        b = self.conv1(b)
+        a = self.conv2(a)
+        if self.upscale:
+             b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
+        return torch.cat((a, b), dim=1)
+
+
+class BlockTypeB(nn.Module):
+    def __init__(self, in_c, out_c):
+        super(BlockTypeB, self).__init__()
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(in_c, in_c,  kernel_size=3, padding=1),
+            nn.BatchNorm2d(in_c),
+            nn.ReLU()
+        )
+        self.conv2 = nn.Sequential(
+            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
+            nn.BatchNorm2d(out_c),
+            nn.ReLU()
+        )
+
+    def forward(self, x):
+        x = self.conv1(x) + x
+        x = self.conv2(x)
+        return x
+
+class BlockTypeC(nn.Module):
+    def __init__(self, in_c, out_c):
+        super(BlockTypeC, self).__init__()
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(in_c, in_c,  kernel_size=3, padding=5, dilation=5),
+            nn.BatchNorm2d(in_c),
+            nn.ReLU()
+        )
+        self.conv2 = nn.Sequential(
+            nn.Conv2d(in_c, in_c,  kernel_size=3, padding=1),
+            nn.BatchNorm2d(in_c),
+            nn.ReLU()
+        )
+        self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.conv3(x)
+        return x
+
+def _make_divisible(v, divisor, min_value=None):
+    """
+    This function is taken from the original tf repo.
+    It ensures that all layers have a channel number that is divisible by 8
+    It can be seen here:
+    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+    :param v:
+    :param divisor:
+    :param min_value:
+    :return:
+    """
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
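+
+# For illustration: _make_divisible(32 * 1.0, 8) == 32 and _make_divisible(35, 8) == 32,
+# while _make_divisible(27, 8) == 32 because rounding down to 24 would lose more than 10%.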
+
+
+class ConvBNReLU(nn.Sequential):
+    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+        self.channel_pad = out_planes - in_planes
+        self.stride = stride
+        #padding = (kernel_size - 1) // 2
+
+        # TFLite uses slightly different padding than PyTorch
+        if stride == 2:
+            padding = 0
+        else:
+            padding = (kernel_size - 1) // 2
+
+        super(ConvBNReLU, self).__init__(
+            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
+            nn.BatchNorm2d(out_planes),
+            nn.ReLU6(inplace=True)
+        )
+        self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
+
+
+    def forward(self, x):
+        # TFLite uses  different padding
+        if self.stride == 2:
+            x = F.pad(x, (0, 1, 0, 1), "constant", 0)
+            #print(x.shape)
+
+        for module in self:
+            if not isinstance(module, nn.MaxPool2d):
+                x = module(x)
+        return x
+
+
+class InvertedResidual(nn.Module):
+    def __init__(self, inp, oup, stride, expand_ratio):
+        super(InvertedResidual, self).__init__()
+        self.stride = stride
+        assert stride in [1, 2]
+
+        hidden_dim = int(round(inp * expand_ratio))
+        self.use_res_connect = self.stride == 1 and inp == oup
+
+        layers = []
+        if expand_ratio != 1:
+            # pw
+            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
+        layers.extend([
+            # dw
+            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
+            # pw-linear
+            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+            nn.BatchNorm2d(oup),
+        ])
+        self.conv = nn.Sequential(*layers)
+
+    def forward(self, x):
+        if self.use_res_connect:
+            return x + self.conv(x)
+        else:
+            return self.conv(x)
+
+
+class MobileNetV2(nn.Module):
+    def __init__(self, pretrained=True):
+        """
+        MobileNet V2 main class
+        Args:
+            num_classes (int): Number of classes
+            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
+            inverted_residual_setting: Network structure
+            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
+            Set to 1 to turn off rounding
+            block: Module specifying inverted residual building block for mobilenet
+        """
+        super(MobileNetV2, self).__init__()
+
+        block = InvertedResidual
+        input_channel = 32
+        last_channel = 1280
+        width_mult = 1.0
+        round_nearest = 8
+
+        inverted_residual_setting = [
+            # t, c, n, s
+            [1, 16, 1, 1],
+            [6, 24, 2, 2],
+            [6, 32, 3, 2],
+            [6, 64, 4, 2],
+            [6, 96, 3, 1],
+            #[6, 160, 3, 2],
+            #[6, 320, 1, 1],
+        ]
+
+        # only check the first element, assuming user knows t,c,n,s are required
+        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
+            raise ValueError("inverted_residual_setting should be non-empty "
+                             "or a 4-element list, got {}".format(inverted_residual_setting))
+
+        # building first layer
+        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
+        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
+        features = [ConvBNReLU(4, input_channel, stride=2)]
+        # building inverted residual blocks
+        for t, c, n, s in inverted_residual_setting:
+            output_channel = _make_divisible(c * width_mult, round_nearest)
+            for i in range(n):
+                stride = s if i == 0 else 1
+                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
+                input_channel = output_channel
+
+        self.features = nn.Sequential(*features)
+        self.fpn_selected = [1, 3, 6, 10, 13]
+        # weight initialization
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out')
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.zeros_(m.bias)
+        if pretrained:
+           self._load_pretrained_model()
+
+    def _forward_impl(self, x):
+        # This exists since TorchScript doesn't support inheritance, so the superclass method
+        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
+        fpn_features = []
+        for i, f in enumerate(self.features):
+            if i > self.fpn_selected[-1]:
+                break
+            x = f(x)
+            if i in self.fpn_selected:
+                fpn_features.append(x)
+
+        c1, c2, c3, c4, c5 = fpn_features
+        return c1, c2, c3, c4, c5
+
+
+    def forward(self, x):
+        return self._forward_impl(x)
+
+    def _load_pretrained_model(self):
+        pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
+        model_dict = {}
+        state_dict = self.state_dict()
+        for k, v in pretrain_dict.items():
+            if k in state_dict:
+                model_dict[k] = v
+        state_dict.update(model_dict)
+        self.load_state_dict(state_dict)
+
+
+class MobileV2_MLSD_Large(nn.Module):
+    def __init__(self):
+        super(MobileV2_MLSD_Large, self).__init__()
+
+        self.backbone = MobileNetV2(pretrained=False)
+        ## A, B
+        self.block15 = BlockTypeA(in_c1= 64, in_c2= 96,
+                                  out_c1= 64, out_c2=64,
+                                  upscale=False)
+        self.block16 = BlockTypeB(128, 64)
+
+        ## A, B
+        self.block17 = BlockTypeA(in_c1 = 32,  in_c2 = 64,
+                                  out_c1= 64,  out_c2= 64)
+        self.block18 = BlockTypeB(128, 64)
+
+        ## A, B
+        self.block19 = BlockTypeA(in_c1=24, in_c2=64,
+                                  out_c1=64, out_c2=64)
+        self.block20 = BlockTypeB(128, 64)
+
+        ## A, B, C
+        self.block21 = BlockTypeA(in_c1=16, in_c2=64,
+                                  out_c1=64, out_c2=64)
+        self.block22 = BlockTypeB(128, 64)
+
+        self.block23 = BlockTypeC(64, 16)
+
+    def forward(self, x):
+        c1, c2, c3, c4, c5 = self.backbone(x)
+
+        x = self.block15(c4, c5)
+        x = self.block16(x)
+
+        x = self.block17(c3, x)
+        x = self.block18(x)
+
+        x = self.block19(c2, x)
+        x = self.block20(x)
+
+        x = self.block21(c1, x)
+        x = self.block22(x)
+        x = self.block23(x)
+        x = x[:, 7:, :, :]
+
+        return x
\ No newline at end of file
diff --git a/annotator/mlsd/models/mbv2_mlsd_tiny.py b/annotator/mlsd/models/mbv2_mlsd_tiny.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3ed633f2cc23ea1829a627fdb879ab39f641f83
--- /dev/null
+++ b/annotator/mlsd/models/mbv2_mlsd_tiny.py
@@ -0,0 +1,275 @@
+import os
+import sys
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+from  torch.nn import  functional as F
+
+
+class BlockTypeA(nn.Module):
+    def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
+        super(BlockTypeA, self).__init__()
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(in_c2, out_c2, kernel_size=1),
+            nn.BatchNorm2d(out_c2),
+            nn.ReLU(inplace=True)
+        )
+        self.conv2 = nn.Sequential(
+            nn.Conv2d(in_c1, out_c1, kernel_size=1),
+            nn.BatchNorm2d(out_c1),
+            nn.ReLU(inplace=True)
+        )
+        self.upscale = upscale
+
+    def forward(self, a, b):
+        b = self.conv1(b)
+        a = self.conv2(a)
+        b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
+        return torch.cat((a, b), dim=1)
+
+
+class BlockTypeB(nn.Module):
+    def __init__(self, in_c, out_c):
+        super(BlockTypeB, self).__init__()
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(in_c, in_c,  kernel_size=3, padding=1),
+            nn.BatchNorm2d(in_c),
+            nn.ReLU()
+        )
+        self.conv2 = nn.Sequential(
+            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
+            nn.BatchNorm2d(out_c),
+            nn.ReLU()
+        )
+
+    def forward(self, x):
+        x = self.conv1(x) + x
+        x = self.conv2(x)
+        return x
+
+class BlockTypeC(nn.Module):
+    def __init__(self, in_c, out_c):
+        super(BlockTypeC, self).__init__()
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(in_c, in_c,  kernel_size=3, padding=5, dilation=5),
+            nn.BatchNorm2d(in_c),
+            nn.ReLU()
+        )
+        self.conv2 = nn.Sequential(
+            nn.Conv2d(in_c, in_c,  kernel_size=3, padding=1),
+            nn.BatchNorm2d(in_c),
+            nn.ReLU()
+        )
+        self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.conv3(x)
+        return x
+
+def _make_divisible(v, divisor, min_value=None):
+    """
+    This function is taken from the original tf repo.
+    It ensures that all layers have a channel number that is divisible by 8
+    It can be seen here:
+    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
+    :param v:
+    :param divisor:
+    :param min_value:
+    :return:
+    """
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+
+class ConvBNReLU(nn.Sequential):
+    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
+        self.channel_pad = out_planes - in_planes
+        self.stride = stride
+        #padding = (kernel_size - 1) // 2
+
+        # TFLite uses slightly different padding than PyTorch
+        if stride == 2:
+            padding = 0
+        else:
+            padding = (kernel_size - 1) // 2
+
+        super(ConvBNReLU, self).__init__(
+            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
+            nn.BatchNorm2d(out_planes),
+            nn.ReLU6(inplace=True)
+        )
+        self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
+
+
+    def forward(self, x):
+        # TFLite uses  different padding
+        if self.stride == 2:
+            x = F.pad(x, (0, 1, 0, 1), "constant", 0)
+            #print(x.shape)
+
+        for module in self:
+            if not isinstance(module, nn.MaxPool2d):
+                x = module(x)
+        return x
+
+
+class InvertedResidual(nn.Module):
+    def __init__(self, inp, oup, stride, expand_ratio):
+        super(InvertedResidual, self).__init__()
+        self.stride = stride
+        assert stride in [1, 2]
+
+        hidden_dim = int(round(inp * expand_ratio))
+        self.use_res_connect = self.stride == 1 and inp == oup
+
+        layers = []
+        if expand_ratio != 1:
+            # pw
+            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
+        layers.extend([
+            # dw
+            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
+            # pw-linear
+            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
+            nn.BatchNorm2d(oup),
+        ])
+        self.conv = nn.Sequential(*layers)
+
+    def forward(self, x):
+        if self.use_res_connect:
+            return x + self.conv(x)
+        else:
+            return self.conv(x)
+
+
+class MobileNetV2(nn.Module):
+    def __init__(self, pretrained=True):
+        """
+        MobileNet V2 main class
+        Args:
+            num_classes (int): Number of classes
+            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
+            inverted_residual_setting: Network structure
+            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
+            Set to 1 to turn off rounding
+            block: Module specifying inverted residual building block for mobilenet
+        """
+        super(MobileNetV2, self).__init__()
+
+        block = InvertedResidual
+        input_channel = 32
+        last_channel = 1280
+        width_mult = 1.0
+        round_nearest = 8
+
+        inverted_residual_setting = [
+            # t, c, n, s
+            [1, 16, 1, 1],
+            [6, 24, 2, 2],
+            [6, 32, 3, 2],
+            [6, 64, 4, 2],
+            #[6, 96, 3, 1],
+            #[6, 160, 3, 2],
+            #[6, 320, 1, 1],
+        ]
+
+        # only check the first element, assuming user knows t,c,n,s are required
+        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
+            raise ValueError("inverted_residual_setting should be non-empty "
+                             "or a 4-element list, got {}".format(inverted_residual_setting))
+
+        # building first layer
+        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
+        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
+        features = [ConvBNReLU(4, input_channel, stride=2)]
+        # building inverted residual blocks
+        for t, c, n, s in inverted_residual_setting:
+            output_channel = _make_divisible(c * width_mult, round_nearest)
+            for i in range(n):
+                stride = s if i == 0 else 1
+                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
+                input_channel = output_channel
+        self.features = nn.Sequential(*features)
+
+        self.fpn_selected = [3, 6, 10]
+        # weight initialization
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out')
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.zeros_(m.bias)
+
+        #if pretrained:
+        #    self._load_pretrained_model()
+
+    def _forward_impl(self, x):
+        # This exists since TorchScript doesn't support inheritance, so the superclass method
+        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
+        fpn_features = []
+        for i, f in enumerate(self.features):
+            if i > self.fpn_selected[-1]:
+                break
+            x = f(x)
+            if i in self.fpn_selected:
+                fpn_features.append(x)
+
+        c2, c3, c4 = fpn_features
+        return c2, c3, c4
+
+
+    def forward(self, x):
+        return self._forward_impl(x)
+
+    def _load_pretrained_model(self):
+        pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
+        model_dict = {}
+        state_dict = self.state_dict()
+        for k, v in pretrain_dict.items():
+            if k in state_dict:
+                model_dict[k] = v
+        state_dict.update(model_dict)
+        self.load_state_dict(state_dict)
+
+
+class MobileV2_MLSD_Tiny(nn.Module):
+    def __init__(self):
+        super(MobileV2_MLSD_Tiny, self).__init__()
+
+        self.backbone = MobileNetV2(pretrained=True)
+
+        self.block12 = BlockTypeA(in_c1= 32, in_c2= 64,
+                                  out_c1= 64, out_c2=64)
+        self.block13 = BlockTypeB(128, 64)
+
+        self.block14 = BlockTypeA(in_c1 = 24,  in_c2 = 64,
+                                  out_c1= 32,  out_c2= 32)
+        self.block15 = BlockTypeB(64, 64)
+
+        self.block16 = BlockTypeC(64, 16)
+
+    def forward(self, x):
+        c2, c3, c4 = self.backbone(x)
+
+        x = self.block12(c3, c4)
+        x = self.block13(x)
+        x = self.block14(c2, x)
+        x = self.block15(x)
+        x = self.block16(x)
+        x = x[:, 7:, :, :]
+        #print(x.shape)
+        x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True)
+
+        return x
\ No newline at end of file
diff --git a/annotator/mlsd/utils.py b/annotator/mlsd/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae3cf9420a33a4abae27c48ac4b90938c7d63cc3
--- /dev/null
+++ b/annotator/mlsd/utils.py
@@ -0,0 +1,580 @@
+'''
+modified by  lihaoweicv
+pytorch version
+'''
+
+'''
+M-LSD
+Copyright 2021-present NAVER Corp.
+Apache License v2.0
+'''
+
+import os
+import numpy as np
+import cv2
+import torch
+from  torch.nn import  functional as F
+
+
+def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5):
+    '''
+    tpMap:
+    center: tpMap[1, 0, :, :]
+    displacement: tpMap[1, 1:5, :, :]
+    '''
+    b, c, h, w = tpMap.shape
+    assert  b==1, 'only support bsize==1'
+    displacement = tpMap[:, 1:5, :, :][0]
+    center = tpMap[:, 0, :, :]
+    heat = torch.sigmoid(center)
+    hmax = F.max_pool2d( heat, (ksize, ksize), stride=1, padding=(ksize-1)//2)
+    keep = (hmax == heat).float()
+    heat = heat * keep
+    heat = heat.reshape(-1, )
+
+    scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True)
+    yy = torch.floor_divide(indices, w).unsqueeze(-1)
+    xx = torch.fmod(indices, w).unsqueeze(-1)
+    ptss = torch.cat((yy, xx),dim=-1)
+
+    ptss   = ptss.detach().cpu().numpy()
+    scores = scores.detach().cpu().numpy()
+    displacement = displacement.detach().cpu().numpy()
+    displacement = displacement.transpose((1,2,0))
+    return  ptss, scores, displacement
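+
+# Editorial note: the max_pool2d / equality mask above is a simple non-maximum suppression:
+# only pixels equal to the local maximum of the sigmoid-activated center map survive, and
+# topk then keeps the strongest `topk_n` junction candidates.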
+
+
+def pred_lines(image, model,
+               input_shape=[512, 512],
+               score_thr=0.10,
+               dist_thr=20.0):
+    h, w, _ = image.shape
+    h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]]
+
+    resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA),
+                                    np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
+
+    resized_image = resized_image.transpose((2,0,1))
+    batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
+    batch_image = (batch_image / 127.5) - 1.0
+
+    batch_image = torch.from_numpy(batch_image).float().cuda()
+    outputs = model(batch_image)
+    pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
+    start = vmap[:, :, :2]
+    end = vmap[:, :, 2:]
+    dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
+
+    segments_list = []
+    for center, score in zip(pts, pts_score):
+        y, x = center
+        distance = dist_map[y, x]
+        if score > score_thr and distance > dist_thr:
+            disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
+            x_start = x + disp_x_start
+            y_start = y + disp_y_start
+            x_end = x + disp_x_end
+            y_end = y + disp_y_end
+            segments_list.append([x_start, y_start, x_end, y_end])
+
+    lines = 2 * np.array(segments_list)  # scale from the 256 output map back to the 512 input
+    lines[:, 0] = lines[:, 0] * w_ratio
+    lines[:, 1] = lines[:, 1] * h_ratio
+    lines[:, 2] = lines[:, 2] * w_ratio
+    lines[:, 3] = lines[:, 3] * h_ratio
+
+    return lines
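+
+# Editorial note: the returned rows are (x_start, y_start, x_end, y_end) segments, already
+# rescaled from the network's 512x512 input (and its half-resolution output map) back to
+# the original image resolution.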
+
+
+def pred_squares(image,
+                 model,
+                 input_shape=[512, 512],
+                 params={'score': 0.06,
+                         'outside_ratio': 0.28,
+                         'inside_ratio': 0.45,
+                         'w_overlap': 0.0,
+                         'w_degree': 1.95,
+                         'w_length': 0.0,
+                         'w_area': 1.86,
+                         'w_center': 0.14}):
+    '''
+    shape = [height, width]
+    '''
+    h, w, _ = image.shape
+    original_shape = [h, w]
+
+    resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA),
+                                    np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
+    resized_image = resized_image.transpose((2, 0, 1))
+    batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
+    batch_image = (batch_image / 127.5) - 1.0
+
+    batch_image = torch.from_numpy(batch_image).float().cuda()
+    outputs = model(batch_image)
+
+    pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
+    start = vmap[:, :, :2]  # (x, y)
+    end = vmap[:, :, 2:]  # (x, y)
+    dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
+
+    junc_list = []
+    segments_list = []
+    for junc, score in zip(pts, pts_score):
+        y, x = junc
+        distance = dist_map[y, x]
+        if score > params['score'] and distance > 20.0:
+            junc_list.append([x, y])
+            disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
+            d_arrow = 1.0
+            x_start = x + d_arrow * disp_x_start
+            y_start = y + d_arrow * disp_y_start
+            x_end = x + d_arrow * disp_x_end
+            y_end = y + d_arrow * disp_y_end
+            segments_list.append([x_start, y_start, x_end, y_end])
+
+    segments = np.array(segments_list)
+
+    ####### post processing for squares
+    # 1. get unique lines
+    point = np.array([[0, 0]])
+    point = point[0]
+    start = segments[:, :2]
+    end = segments[:, 2:]
+    diff = start - end
+    a = diff[:, 1]
+    b = -diff[:, 0]
+    c = a * start[:, 0] + b * start[:, 1]
+
+    d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a ** 2 + b ** 2 + 1e-10)
+    theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi
+    theta[theta < 0.0] += 180
+    hough = np.concatenate([d[:, None], theta[:, None]], axis=-1)
+
+    d_quant = 1
+    theta_quant = 2
+    hough[:, 0] //= d_quant
+    hough[:, 1] //= theta_quant
+    _, indices, counts = np.unique(hough, axis=0, return_index=True, return_counts=True)
+
+    acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='float32')
+    idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='int32') - 1
+    yx_indices = hough[indices, :].astype('int32')
+    acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts
+    idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices
+
+    acc_map_np = acc_map
+    # acc_map = acc_map[None, :, :, None]
+    #
+    # ### fast suppression using tensorflow op
+    # acc_map = tf.constant(acc_map, dtype=tf.float32)
+    # max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map)
+    # acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32)
+    # flatten_acc_map = tf.reshape(acc_map, [1, -1])
+    # topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts))
+    # _, h, w, _ = acc_map.shape
+    # y = tf.expand_dims(topk_indices // w, axis=-1)
+    # x = tf.expand_dims(topk_indices % w, axis=-1)
+    # yx = tf.concat([y, x], axis=-1)
+
+    ### fast suppression using pytorch op
+    acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0)
+    _,_, h, w = acc_map.shape
+    max_acc_map = F.max_pool2d(acc_map,kernel_size=5, stride=1, padding=2)
+    acc_map = acc_map * ( (acc_map == max_acc_map).float() )
+    flatten_acc_map = acc_map.reshape([-1, ])
+
+    scores, indices = torch.topk(flatten_acc_map, len(pts), dim=-1, largest=True)
+    yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1)
+    xx = torch.fmod(indices, w).unsqueeze(-1)
+    yx = torch.cat((yy, xx), dim=-1)
+
+    yx = yx.detach().cpu().numpy()
+
+    topk_values = scores.detach().cpu().numpy()
+    indices = idx_map[yx[:, 0], yx[:, 1]]
+    basis = 5 // 2
+
+    merged_segments = []
+    for yx_pt, max_indice, value in zip(yx, indices, topk_values):
+        y, x = yx_pt
+        if max_indice == -1 or value == 0:
+            continue
+        segment_list = []
+        for y_offset in range(-basis, basis + 1):
+            for x_offset in range(-basis, basis + 1):
+                indice = idx_map[y + y_offset, x + x_offset]
+                cnt = int(acc_map_np[y + y_offset, x + x_offset])
+                if indice != -1:
+                    segment_list.append(segments[indice])
+                if cnt > 1:
+                    check_cnt = 1
+                    current_hough = hough[indice]
+                    for new_indice, new_hough in enumerate(hough):
+                        if (current_hough == new_hough).all() and indice != new_indice:
+                            segment_list.append(segments[new_indice])
+                            check_cnt += 1
+                        if check_cnt == cnt:
+                            break
+        group_segments = np.array(segment_list).reshape([-1, 2])
+        sorted_group_segments = np.sort(group_segments, axis=0)
+        x_min, y_min = sorted_group_segments[0, :]
+        x_max, y_max = sorted_group_segments[-1, :]
+
+        deg = theta[max_indice]
+        if deg >= 90:
+            merged_segments.append([x_min, y_max, x_max, y_min])
+        else:
+            merged_segments.append([x_min, y_min, x_max, y_max])
+
+    # 2. get intersections
+    new_segments = np.array(merged_segments)  # (x1, y1, x2, y2)
+    start = new_segments[:, :2]  # (x1, y1)
+    end = new_segments[:, 2:]  # (x2, y2)
+    new_centers = (start + end) / 2.0
+    diff = start - end
+    dist_segments = np.sqrt(np.sum(diff ** 2, axis=-1))
+
+    # ax + by = c
+    a = diff[:, 1]
+    b = -diff[:, 0]
+    c = a * start[:, 0] + b * start[:, 1]
+    pre_det = a[:, None] * b[None, :]
+    det = pre_det - np.transpose(pre_det)
+
+    pre_inter_y = a[:, None] * c[None, :]
+    inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10)
+    pre_inter_x = c[:, None] * b[None, :]
+    inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10)
+    inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]], axis=-1).astype('int32')
+
+    # 3. get corner information
+    # 3.1 get distance
+    '''
+    dist_segments:
+        | dist(0), dist(1), dist(2), ...|
+    dist_inter_to_segment1:
+        | dist(inter,0), dist(inter,0), dist(inter,0), ... |
+        | dist(inter,1), dist(inter,1), dist(inter,1), ... |
+        ...
+    dist_inter_to_segment2:
+        | dist(inter,0), dist(inter,1), dist(inter,2), ... |
+        | dist(inter,0), dist(inter,1), dist(inter,2), ... |
+        ...
+    '''
+
+    dist_inter_to_segment1_start = np.sqrt(
+        np.sum(((inter_pts - start[:, None, :]) ** 2), axis=-1, keepdims=True))  # [n_batch, n_batch, 1]
+    dist_inter_to_segment1_end = np.sqrt(
+        np.sum(((inter_pts - end[:, None, :]) ** 2), axis=-1, keepdims=True))  # [n_batch, n_batch, 1]
+    dist_inter_to_segment2_start = np.sqrt(
+        np.sum(((inter_pts - start[None, :, :]) ** 2), axis=-1, keepdims=True))  # [n_batch, n_batch, 1]
+    dist_inter_to_segment2_end = np.sqrt(
+        np.sum(((inter_pts - end[None, :, :]) ** 2), axis=-1, keepdims=True))  # [n_batch, n_batch, 1]
+
+    # sort ascending
+    dist_inter_to_segment1 = np.sort(
+        np.concatenate([dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1),
+        axis=-1)  # [n_batch, n_batch, 2]
+    dist_inter_to_segment2 = np.sort(
+        np.concatenate([dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1),
+        axis=-1)  # [n_batch, n_batch, 2]
+
+    # 3.2 get degree
+    inter_to_start = new_centers[:, None, :] - inter_pts
+    deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1], inter_to_start[:, :, 0]) * 180 / np.pi
+    deg_inter_to_start[deg_inter_to_start < 0.0] += 360
+    inter_to_end = new_centers[None, :, :] - inter_pts
+    deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1], inter_to_end[:, :, 0]) * 180 / np.pi
+    deg_inter_to_end[deg_inter_to_end < 0.0] += 360
+
+    '''
+    B -- G
+    |    |
+    C -- R
+    B : blue / G: green / C: cyan / R: red
+
+    0 -- 1
+    |    |
+    3 -- 2
+    '''
+    # rename variables
+    deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end
+    # sort deg ascending
+    deg_sort = np.sort(np.concatenate([deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1), axis=-1)
+
+    deg_diff_map = np.abs(deg1_map - deg2_map)
+    # we only consider the smallest degree of intersect
+    deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180]
+
+    # define available degree range
+    deg_range = [60, 120]
+
+    corner_dict = {corner_info: [] for corner_info in range(4)}
+    inter_points = []
+    for i in range(inter_pts.shape[0]):
+        for j in range(i + 1, inter_pts.shape[1]):
+            # i, j > line index, always i < j
+            x, y = inter_pts[i, j, :]
+            deg1, deg2 = deg_sort[i, j, :]
+            deg_diff = deg_diff_map[i, j]
+
+            check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1]
+
+            outside_ratio = params['outside_ratio']  # over ratio >>> drop it!
+            inside_ratio = params['inside_ratio']  # over ratio >>> drop it!
+            check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and \
+                               dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or \
+                              (dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and \
+                               dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \
+                             ((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and \
+                               dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or \
+                              (dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and \
+                               dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio))
+
+            if check_degree and check_distance:
+                corner_info = None
+
+                if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \
+                        (deg2 >= 315 and deg1 >= 45 and deg1 <= 120):
+                    corner_info, color_info = 0, 'blue'
+                elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125 and deg2 <= 225):
+                    corner_info, color_info = 1, 'green'
+                elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225 and deg2 <= 315):
+                    corner_info, color_info = 2, 'black'
+                elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \
+                        (deg2 >= 315 and deg1 >= 225 and deg1 <= 315):
+                    corner_info, color_info = 3, 'cyan'
+                else:
+                    corner_info, color_info = 4, 'red'  # we don't use it
+                    continue
+
+                corner_dict[corner_info].append([x, y, i, j])
+                inter_points.append([x, y])
+
+    square_list = []
+    connect_list = []
+    segments_list = []
+    for corner0 in corner_dict[0]:
+        for corner1 in corner_dict[1]:
+            connect01 = False
+            for corner0_line in corner0[2:]:
+                if corner0_line in corner1[2:]:
+                    connect01 = True
+                    break
+            if connect01:
+                for corner2 in corner_dict[2]:
+                    connect12 = False
+                    for corner1_line in corner1[2:]:
+                        if corner1_line in corner2[2:]:
+                            connect12 = True
+                            break
+                    if connect12:
+                        for corner3 in corner_dict[3]:
+                            connect23 = False
+                            for corner2_line in corner2[2:]:
+                                if corner2_line in corner3[2:]:
+                                    connect23 = True
+                                    break
+                            if connect23:
+                                for corner3_line in corner3[2:]:
+                                    if corner3_line in corner0[2:]:
+                                        # SQUARE!!!
+                                        '''
+                                        0 -- 1
+                                        |    |
+                                        3 -- 2
+                                        square_list:
+                                            order: 0 > 1 > 2 > 3
+                                            | x0, y0, x1, y1, x2, y2, x3, y3 |
+                                            | x0, y0, x1, y1, x2, y2, x3, y3 |
+                                            ...
+                                        connect_list:
+                                            order: 01 > 12 > 23 > 30
+                                            | line_idx01, line_idx12, line_idx23, line_idx30 |
+                                            | line_idx01, line_idx12, line_idx23, line_idx30 |
+                                            ...
+                                        segments_list:
+                                            order: 0 > 1 > 2 > 3
+                                            | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
+                                            | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
+                                            ...
+                                        '''
+                                        square_list.append(corner0[:2] + corner1[:2] + corner2[:2] + corner3[:2])
+                                        connect_list.append([corner0_line, corner1_line, corner2_line, corner3_line])
+                                        segments_list.append(corner0[2:] + corner1[2:] + corner2[2:] + corner3[2:])
+
+    def check_outside_inside(segments_info, connect_idx):
+        # return 'outside or inside', min distance, cover_param, peri_param
+        if connect_idx == segments_info[0]:
+            check_dist_mat = dist_inter_to_segment1
+        else:
+            check_dist_mat = dist_inter_to_segment2
+
+        i, j = segments_info
+        min_dist, max_dist = check_dist_mat[i, j, :]
+        connect_dist = dist_segments[connect_idx]
+        if max_dist > connect_dist:
+            return 'outside', min_dist, 0, 1
+        else:
+            return 'inside', min_dist, -1, -1
+
+    top_square = None
+
+    try:
+        map_size = input_shape[0] / 2
+        squares = np.array(square_list).reshape([-1, 4, 2])
+        score_array = []
+        connect_array = np.array(connect_list)
+        segments_array = np.array(segments_list).reshape([-1, 4, 2])
+
+        # get degree of corners:
+        squares_rollup = np.roll(squares, 1, axis=1)
+        squares_rolldown = np.roll(squares, -1, axis=1)
+        vec1 = squares_rollup - squares
+        normalized_vec1 = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10)
+        vec2 = squares_rolldown - squares
+        normalized_vec2 = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10)
+        inner_products = np.sum(normalized_vec1 * normalized_vec2, axis=-1)  # [n_squares, 4]
+        squares_degree = np.arccos(inner_products) * 180 / np.pi  # [n_squares, 4]
+
+        # get square score
+        overlap_scores = []
+        degree_scores = []
+        length_scores = []
+
+        for connects, segments, square, degree in zip(connect_array, segments_array, squares, squares_degree):
+            '''
+            0 -- 1
+            |    |
+            3 -- 2
+
+            # segments: [4, 2]
+            # connects: [4]
+            '''
+
+            ###################################### OVERLAP SCORES
+            cover = 0
+            perimeter = 0
+            # check 0 > 1 > 2 > 3
+            square_length = []
+
+            for start_idx in range(4):
+                end_idx = (start_idx + 1) % 4
+
+                connect_idx = connects[start_idx]  # segment idx of segment01
+                start_segments = segments[start_idx]
+                end_segments = segments[end_idx]
+
+                start_point = square[start_idx]
+                end_point = square[end_idx]
+
+                # check whether outside or inside
+                start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(start_segments,
+                                                                                                      connect_idx)
+                end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(end_segments, connect_idx)
+
+                cover += dist_segments[connect_idx] + start_cover_param * start_min + end_cover_param * end_min
+                perimeter += dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min
+
+                square_length.append(
+                    dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min)
+
+            overlap_scores.append(cover / perimeter)
+            ######################################
+            ###################################### DEGREE SCORES
+            '''
+            deg0 vs deg2
+            deg1 vs deg3
+            '''
+            deg0, deg1, deg2, deg3 = degree
+            deg_ratio1 = deg0 / deg2
+            if deg_ratio1 > 1.0:
+                deg_ratio1 = 1 / deg_ratio1
+            deg_ratio2 = deg1 / deg3
+            if deg_ratio2 > 1.0:
+                deg_ratio2 = 1 / deg_ratio2
+            degree_scores.append((deg_ratio1 + deg_ratio2) / 2)
+            ######################################
+            ###################################### LENGTH SCORES
+            '''
+            len0 vs len2
+            len1 vs len3
+            '''
+            len0, len1, len2, len3 = square_length
+            len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0
+            len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1
+            length_scores.append((len_ratio1 + len_ratio2) / 2)
+
+            ######################################
+
+        overlap_scores = np.array(overlap_scores)
+        overlap_scores /= np.max(overlap_scores)
+
+        degree_scores = np.array(degree_scores)
+        # degree_scores /= np.max(degree_scores)
+
+        length_scores = np.array(length_scores)
+
+        ###################################### AREA SCORES
+        area_scores = np.reshape(squares, [-1, 4, 2])
+        area_x = area_scores[:, :, 0]
+        area_y = area_scores[:, :, 1]
+        correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:, 0]
+        area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(area_y[:, :-1] * area_x[:, 1:], axis=-1)
+        area_scores = 0.5 * np.abs(area_scores + correction)
+        area_scores /= (map_size * map_size)  # np.max(area_scores)
+        ######################################
+
+        ###################################### CENTER SCORES
+        centers = np.array([[256 // 2, 256 // 2]], dtype='float32')  # [1, 2]
+        # squares: [n, 4, 2]
+        square_centers = np.mean(squares, axis=1)  # [n, 2]
+        center2center = np.sqrt(np.sum((centers - square_centers) ** 2))
+        center_scores = center2center / (map_size / np.sqrt(2.0))
+
+        '''
+        score_w = [overlap, degree, area, center, length]
+        '''
+        score_w = [0.0, 1.0, 10.0, 0.5, 1.0]
+        score_array = params['w_overlap'] * overlap_scores \
+                      + params['w_degree'] * degree_scores \
+                      + params['w_area'] * area_scores \
+                      - params['w_center'] * center_scores \
+                      + params['w_length'] * length_scores
+
+        best_square = []
+
+        sorted_idx = np.argsort(score_array)[::-1]
+        score_array = score_array[sorted_idx]
+        squares = squares[sorted_idx]
+
+    except Exception:
+        pass
+
+    '''return list
+    merged_lines, squares, scores
+    '''
+
+    try:
+        new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[1] * original_shape[1]
+        new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[0] * original_shape[0]
+        new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[1] * original_shape[1]
+        new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[0] * original_shape[0]
+    except:
+        new_segments = []
+
+    try:
+        squares[:, :, 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1]
+        squares[:, :, 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0]
+    except:
+        squares = []
+        score_array = []
+
+    try:
+        inter_points = np.array(inter_points)
+        inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[1] * original_shape[1]
+        inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[0] * original_shape[0]
+    except:
+        inter_points = []
+
+    return new_segments, squares, score_array, inter_points
diff --git a/annotator/openpose/__init__.py b/annotator/openpose/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c26f1b37dae854f51da938da2fa67a8ef48ce5a
--- /dev/null
+++ b/annotator/openpose/__init__.py
@@ -0,0 +1,44 @@
+import os
+os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
+
+import torch
+import numpy as np
+from . import util
+from .body import Body
+from .hand import Hand
+from annotator.util import annotator_ckpts_path
+
+
+body_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/body_pose_model.pth"
+hand_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/hand_pose_model.pth"
+
+
+class OpenposeDetector:
+    def __init__(self):
+        body_modelpath = os.path.join(annotator_ckpts_path, "body_pose_model.pth")
+        hand_modelpath = os.path.join(annotator_ckpts_path, "hand_pose_model.pth")
+
+        if not os.path.exists(hand_modelpath):
+            from basicsr.utils.download_util import load_file_from_url
+            load_file_from_url(body_model_path, model_dir=annotator_ckpts_path)
+            load_file_from_url(hand_model_path, model_dir=annotator_ckpts_path)
+
+        self.body_estimation = Body(body_modelpath)
+        self.hand_estimation = Hand(hand_modelpath)
+
+    def __call__(self, oriImg, hand=False):
+        oriImg = oriImg[:, :, ::-1].copy()
+        with torch.no_grad():
+            candidate, subset = self.body_estimation(oriImg)
+            canvas = np.zeros_like(oriImg)
+            canvas = util.draw_bodypose(canvas, candidate, subset)
+            if hand:
+                hands_list = util.handDetect(candidate, subset, oriImg)
+                all_hand_peaks = []
+                for x, y, w, is_left in hands_list:
+                    peaks = self.hand_estimation(oriImg[y:y+w, x:x+w, :])
+                    peaks[:, 0] = np.where(peaks[:, 0] == 0, peaks[:, 0], peaks[:, 0] + x)
+                    peaks[:, 1] = np.where(peaks[:, 1] == 0, peaks[:, 1], peaks[:, 1] + y)
+                    all_hand_peaks.append(peaks)
+                canvas = util.draw_handpose(canvas, all_hand_peaks)
+            return canvas, dict(candidate=candidate.tolist(), subset=subset.tolist())
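+
+
+# Editor's sketch (not part of the upstream file): minimal usage in the style of the
+# __main__ demos in body.py/hand.py. The image path is a placeholder assumption.
+if __name__ == "__main__":
+    import cv2
+    detector = OpenposeDetector()
+    ori_img = cv2.imread("test_imgs/pose.png")     # hypothetical BGR test image
+    canvas, pose = detector(ori_img, hand=True)    # canvas: HxWx3 uint8 drawing, pose: keypoint dict
+    cv2.imwrite("pose_preview.png", canvas)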
diff --git a/annotator/openpose/body.py b/annotator/openpose/body.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c3cf7a388b4ac81004524e64125e383bdd455bd
--- /dev/null
+++ b/annotator/openpose/body.py
@@ -0,0 +1,219 @@
+import cv2
+import numpy as np
+import math
+import time
+from scipy.ndimage import gaussian_filter
+import matplotlib.pyplot as plt
+import matplotlib
+import torch
+from torchvision import transforms
+
+from . import util
+from .model import bodypose_model
+
+class Body(object):
+    def __init__(self, model_path):
+        self.model = bodypose_model()
+        if torch.cuda.is_available():
+            self.model = self.model.cuda()
+            print('cuda')
+        model_dict = util.transfer(self.model, torch.load(model_path))
+        self.model.load_state_dict(model_dict)
+        self.model.eval()
+
+    def __call__(self, oriImg):
+        # scale_search = [0.5, 1.0, 1.5, 2.0]
+        scale_search = [0.5]
+        boxsize = 368
+        stride = 8
+        padValue = 128
+        thre1 = 0.1
+        thre2 = 0.05
+        multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
+        heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
+        paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
+
+        for m in range(len(multiplier)):
+            scale = multiplier[m]
+            imageToTest = cv2.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
+            imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
+            im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
+            im = np.ascontiguousarray(im)
+
+            data = torch.from_numpy(im).float()
+            if torch.cuda.is_available():
+                data = data.cuda()
+            # data = data.permute([2, 0, 1]).unsqueeze(0).float()
+            with torch.no_grad():
+                Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
+            Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
+            Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()
+
+            # extract outputs, resize, and remove padding
+            # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0))  # output 1 is heatmaps
+            heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0))  # output 1 is heatmaps
+            heatmap = cv2.resize(heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
+            heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+            heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
+
+            # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0))  # output 0 is PAFs
+            paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0))  # output 0 is PAFs
+            paf = cv2.resize(paf, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
+            paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+            paf = cv2.resize(paf, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
+
+            # accumulate the average over the search scales
+            heatmap_avg += heatmap / len(multiplier)
+            paf_avg += paf / len(multiplier)
+
+        all_peaks = []
+        peak_counter = 0
+
+        for part in range(18):
+            map_ori = heatmap_avg[:, :, part]
+            one_heatmap = gaussian_filter(map_ori, sigma=3)
+
+            map_left = np.zeros(one_heatmap.shape)
+            map_left[1:, :] = one_heatmap[:-1, :]
+            map_right = np.zeros(one_heatmap.shape)
+            map_right[:-1, :] = one_heatmap[1:, :]
+            map_up = np.zeros(one_heatmap.shape)
+            map_up[:, 1:] = one_heatmap[:, :-1]
+            map_down = np.zeros(one_heatmap.shape)
+            map_down[:, :-1] = one_heatmap[:, 1:]
+
+            peaks_binary = np.logical_and.reduce(
+                (one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1))
+            peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0]))  # note reverse
+            peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks]
+            peak_id = range(peak_counter, peak_counter + len(peaks))
+            peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))]
+
+            all_peaks.append(peaks_with_score_and_id)
+            peak_counter += len(peaks)
+
+        # find connection in the specified sequence, center 29 is in the position 15
+        limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
+                   [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
+                   [1, 16], [16, 18], [3, 17], [6, 18]]
+        # the middle joints heatmap correspondence
+        mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \
+                  [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \
+                  [55, 56], [37, 38], [45, 46]]
+
+        connection_all = []
+        special_k = []
+        mid_num = 10
+
+        for k in range(len(mapIdx)):
+            score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
+            candA = all_peaks[limbSeq[k][0] - 1]
+            candB = all_peaks[limbSeq[k][1] - 1]
+            nA = len(candA)
+            nB = len(candB)
+            indexA, indexB = limbSeq[k]
+            if (nA != 0 and nB != 0):
+                connection_candidate = []
+                for i in range(nA):
+                    for j in range(nB):
+                        vec = np.subtract(candB[j][:2], candA[i][:2])
+                        norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
+                        norm = max(0.001, norm)
+                        vec = np.divide(vec, norm)
+
+                        startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \
+                                            np.linspace(candA[i][1], candB[j][1], num=mid_num)))
+
+                        vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \
+                                          for I in range(len(startend))])
+                        vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \
+                                          for I in range(len(startend))])
+
+                        score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
+                        score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(
+                            0.5 * oriImg.shape[0] / norm - 1, 0)
+                        criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
+                        criterion2 = score_with_dist_prior > 0
+                        if criterion1 and criterion2:
+                            connection_candidate.append(
+                                [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]])
+
+                connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
+                connection = np.zeros((0, 5))
+                for c in range(len(connection_candidate)):
+                    i, j, s = connection_candidate[c][0:3]
+                    if (i not in connection[:, 3] and j not in connection[:, 4]):
+                        connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
+                        if (len(connection) >= min(nA, nB)):
+                            break
+
+                connection_all.append(connection)
+            else:
+                special_k.append(k)
+                connection_all.append([])
+
+        # last number in each row is the total parts number of that person
+        # the second last number in each row is the score of the overall configuration
+        subset = -1 * np.ones((0, 20))
+        candidate = np.array([item for sublist in all_peaks for item in sublist])
+
+        for k in range(len(mapIdx)):
+            if k not in special_k:
+                partAs = connection_all[k][:, 0]
+                partBs = connection_all[k][:, 1]
+                indexA, indexB = np.array(limbSeq[k]) - 1
+
+                for i in range(len(connection_all[k])):  # = 1:size(temp,1)
+                    found = 0
+                    subset_idx = [-1, -1]
+                    for j in range(len(subset)):  # 1:size(subset,1):
+                        if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
+                            subset_idx[found] = j
+                            found += 1
+
+                    if found == 1:
+                        j = subset_idx[0]
+                        if subset[j][indexB] != partBs[i]:
+                            subset[j][indexB] = partBs[i]
+                            subset[j][-1] += 1
+                            subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
+                    elif found == 2:  # if found 2 and disjoint, merge them
+                        j1, j2 = subset_idx
+                        membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
+                        if len(np.nonzero(membership == 2)[0]) == 0:  # merge
+                            subset[j1][:-2] += (subset[j2][:-2] + 1)
+                            subset[j1][-2:] += subset[j2][-2:]
+                            subset[j1][-2] += connection_all[k][i][2]
+                            subset = np.delete(subset, j2, 0)
+                        else:  # as like found == 1
+                            subset[j1][indexB] = partBs[i]
+                            subset[j1][-1] += 1
+                            subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
+
+                    # if find no partA in the subset, create a new subset
+                    elif not found and k < 17:
+                        row = -1 * np.ones(20)
+                        row[indexA] = partAs[i]
+                        row[indexB] = partBs[i]
+                        row[-1] = 2
+                        row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
+                        subset = np.vstack([subset, row])
+        # delete some rows of subset which has few parts occur
+        deleteIdx = []
+        for i in range(len(subset)):
+            if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
+                deleteIdx.append(i)
+        subset = np.delete(subset, deleteIdx, axis=0)
+
+        # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
+        # candidate: x, y, score, id
+        return candidate, subset
+
+if __name__ == "__main__":
+    body_estimation = Body('../model/body_pose_model.pth')
+
+    test_image = '../images/ski.jpg'
+    oriImg = cv2.imread(test_image)  # B,G,R order
+    candidate, subset = body_estimation(oriImg)
+    canvas = util.draw_bodypose(oriImg, candidate, subset)
+    plt.imshow(canvas[:, :, [2, 1, 0]])
+    plt.show()
diff --git a/annotator/openpose/hand.py b/annotator/openpose/hand.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d0bf17165ad7eb225332b51f4a2aa16718664b2
--- /dev/null
+++ b/annotator/openpose/hand.py
@@ -0,0 +1,86 @@
+import cv2
+import json
+import numpy as np
+import math
+import time
+from scipy.ndimage import gaussian_filter
+import matplotlib.pyplot as plt
+import matplotlib
+import torch
+from skimage.measure import label
+
+from .model import handpose_model
+from . import util
+
+class Hand(object):
+    def __init__(self, model_path):
+        self.model = handpose_model()
+        if torch.cuda.is_available():
+            self.model = self.model.cuda()
+            print('cuda')
+        model_dict = util.transfer(self.model, torch.load(model_path))
+        self.model.load_state_dict(model_dict)
+        self.model.eval()
+
+    def __call__(self, oriImg):
+        scale_search = [0.5, 1.0, 1.5, 2.0]
+        # scale_search = [0.5]
+        boxsize = 368
+        stride = 8
+        padValue = 128
+        thre = 0.05
+        multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
+        heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 22))
+        # paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
+
+        for m in range(len(multiplier)):
+            scale = multiplier[m]
+            imageToTest = cv2.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
+            imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
+            im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
+            im = np.ascontiguousarray(im)
+
+            data = torch.from_numpy(im).float()
+            if torch.cuda.is_available():
+                data = data.cuda()
+            # data = data.permute([2, 0, 1]).unsqueeze(0).float()
+            with torch.no_grad():
+                output = self.model(data).cpu().numpy()
+                # output = self.model(data).numpy()
+
+            # extract outputs, resize, and remove padding
+            heatmap = np.transpose(np.squeeze(output), (1, 2, 0))  # output 1 is heatmaps
+            heatmap = cv2.resize(heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
+            heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+            heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
+
+            heatmap_avg += heatmap / len(multiplier)
+
+        all_peaks = []
+        for part in range(21):
+            map_ori = heatmap_avg[:, :, part]
+            one_heatmap = gaussian_filter(map_ori, sigma=3)
+            binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)
+            # all values are below the threshold
+            if np.sum(binary) == 0:
+                all_peaks.append([0, 0])
+                continue
+            label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim)
+            max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1
+            label_img[label_img != max_index] = 0
+            map_ori[label_img == 0] = 0
+
+            y, x = util.npmax(map_ori)
+            all_peaks.append([x, y])
+        return np.array(all_peaks)
+
+if __name__ == "__main__":
+    hand_estimation = Hand('../model/hand_pose_model.pth')
+
+    # test_image = '../images/hand.jpg'
+    test_image = '../images/hand.jpg'
+    oriImg = cv2.imread(test_image)  # B,G,R order
+    peaks = hand_estimation(oriImg)
+    canvas = util.draw_handpose(oriImg, peaks, True)
+    cv2.imshow('', canvas)
+    cv2.waitKey(0)
\ No newline at end of file
diff --git a/annotator/openpose/model.py b/annotator/openpose/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dfc80de827a17beccb9b0f3f7588545be78c9de
--- /dev/null
+++ b/annotator/openpose/model.py
@@ -0,0 +1,219 @@
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+
+def make_layers(block, no_relu_layers):
+    layers = []
+    for layer_name, v in block.items():
+        if 'pool' in layer_name:
+            layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1],
+                                    padding=v[2])
+            layers.append((layer_name, layer))
+        else:
+            conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1],
+                               kernel_size=v[2], stride=v[3],
+                               padding=v[4])
+            layers.append((layer_name, conv2d))
+            if layer_name not in no_relu_layers:
+                layers.append(('relu_'+layer_name, nn.ReLU(inplace=True)))
+
+    return nn.Sequential(OrderedDict(layers))
+
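+# Editor's note: a minimal sketch of the block format consumed by make_layers (an
+# illustration, not used elsewhere). Conv entries are [in_ch, out_ch, kernel, stride, pad];
+# entries whose name contains 'pool' are [kernel, stride, pad]; names listed in
+# no_relu_layers get no trailing ReLU.
+#
+#   example_block = OrderedDict([
+#       ('conv_a', [3, 16, 3, 1, 1]),        # Conv2d(3, 16, 3, stride=1, padding=1) + ReLU
+#       ('pool_a', [2, 2, 0]),               # MaxPool2d(kernel_size=2, stride=2, padding=0)
+#       ('conv_b', [16, 8, 1, 1, 0]),        # no ReLU since 'conv_b' is in no_relu_layers
+#   ])
+#   tiny = make_layers(example_block, no_relu_layers=['conv_b'])
+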
+class bodypose_model(nn.Module):
+    def __init__(self):
+        super(bodypose_model, self).__init__()
+
+        # these layers have no relu layer
+        no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\
+                          'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\
+                          'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\
+                          'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L2']
+        blocks = {}
+        block0 = OrderedDict([
+                      ('conv1_1', [3, 64, 3, 1, 1]),
+                      ('conv1_2', [64, 64, 3, 1, 1]),
+                      ('pool1_stage1', [2, 2, 0]),
+                      ('conv2_1', [64, 128, 3, 1, 1]),
+                      ('conv2_2', [128, 128, 3, 1, 1]),
+                      ('pool2_stage1', [2, 2, 0]),
+                      ('conv3_1', [128, 256, 3, 1, 1]),
+                      ('conv3_2', [256, 256, 3, 1, 1]),
+                      ('conv3_3', [256, 256, 3, 1, 1]),
+                      ('conv3_4', [256, 256, 3, 1, 1]),
+                      ('pool3_stage1', [2, 2, 0]),
+                      ('conv4_1', [256, 512, 3, 1, 1]),
+                      ('conv4_2', [512, 512, 3, 1, 1]),
+                      ('conv4_3_CPM', [512, 256, 3, 1, 1]),
+                      ('conv4_4_CPM', [256, 128, 3, 1, 1])
+                  ])
+
+
+        # Stage 1
+        block1_1 = OrderedDict([
+                        ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]),
+                        ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
+                        ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]),
+                        ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
+                        ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])
+                    ])
+
+        block1_2 = OrderedDict([
+                        ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]),
+                        ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
+                        ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]),
+                        ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
+                        ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])
+                    ])
+        blocks['block1_1'] = block1_1
+        blocks['block1_2'] = block1_2
+
+        self.model0 = make_layers(block0, no_relu_layers)
+
+        # Stages 2 - 6
+        for i in range(2, 7):
+            blocks['block%d_1' % i] = OrderedDict([
+                    ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]),
+                    ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+                    ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+                    ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+                    ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+                    ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]),
+                    ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])
+                ])
+
+            blocks['block%d_2' % i] = OrderedDict([
+                    ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]),
+                    ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+                    ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+                    ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+                    ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+                    ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]),
+                    ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])
+                ])
+
+        for k in blocks.keys():
+            blocks[k] = make_layers(blocks[k], no_relu_layers)
+
+        self.model1_1 = blocks['block1_1']
+        self.model2_1 = blocks['block2_1']
+        self.model3_1 = blocks['block3_1']
+        self.model4_1 = blocks['block4_1']
+        self.model5_1 = blocks['block5_1']
+        self.model6_1 = blocks['block6_1']
+
+        self.model1_2 = blocks['block1_2']
+        self.model2_2 = blocks['block2_2']
+        self.model3_2 = blocks['block3_2']
+        self.model4_2 = blocks['block4_2']
+        self.model5_2 = blocks['block5_2']
+        self.model6_2 = blocks['block6_2']
+
+
+    def forward(self, x):
+
+        out1 = self.model0(x)
+
+        out1_1 = self.model1_1(out1)
+        out1_2 = self.model1_2(out1)
+        out2 = torch.cat([out1_1, out1_2, out1], 1)
+
+        out2_1 = self.model2_1(out2)
+        out2_2 = self.model2_2(out2)
+        out3 = torch.cat([out2_1, out2_2, out1], 1)
+
+        out3_1 = self.model3_1(out3)
+        out3_2 = self.model3_2(out3)
+        out4 = torch.cat([out3_1, out3_2, out1], 1)
+
+        out4_1 = self.model4_1(out4)
+        out4_2 = self.model4_2(out4)
+        out5 = torch.cat([out4_1, out4_2, out1], 1)
+
+        out5_1 = self.model5_1(out5)
+        out5_2 = self.model5_2(out5)
+        out6 = torch.cat([out5_1, out5_2, out1], 1)
+
+        out6_1 = self.model6_1(out6)
+        out6_2 = self.model6_2(out6)
+
+        return out6_1, out6_2
+
+class handpose_model(nn.Module):
+    def __init__(self):
+        super(handpose_model, self).__init__()
+
+        # these layers have no relu layer
+        no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\
+                          'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6']
+        # stage 1
+        block1_0 = OrderedDict([
+                ('conv1_1', [3, 64, 3, 1, 1]),
+                ('conv1_2', [64, 64, 3, 1, 1]),
+                ('pool1_stage1', [2, 2, 0]),
+                ('conv2_1', [64, 128, 3, 1, 1]),
+                ('conv2_2', [128, 128, 3, 1, 1]),
+                ('pool2_stage1', [2, 2, 0]),
+                ('conv3_1', [128, 256, 3, 1, 1]),
+                ('conv3_2', [256, 256, 3, 1, 1]),
+                ('conv3_3', [256, 256, 3, 1, 1]),
+                ('conv3_4', [256, 256, 3, 1, 1]),
+                ('pool3_stage1', [2, 2, 0]),
+                ('conv4_1', [256, 512, 3, 1, 1]),
+                ('conv4_2', [512, 512, 3, 1, 1]),
+                ('conv4_3', [512, 512, 3, 1, 1]),
+                ('conv4_4', [512, 512, 3, 1, 1]),
+                ('conv5_1', [512, 512, 3, 1, 1]),
+                ('conv5_2', [512, 512, 3, 1, 1]),
+                ('conv5_3_CPM', [512, 128, 3, 1, 1])
+            ])
+
+        block1_1 = OrderedDict([
+            ('conv6_1_CPM', [128, 512, 1, 1, 0]),
+            ('conv6_2_CPM', [512, 22, 1, 1, 0])
+        ])
+
+        blocks = {}
+        blocks['block1_0'] = block1_0
+        blocks['block1_1'] = block1_1
+
+        # stage 2-6
+        for i in range(2, 7):
+            blocks['block%d' % i] = OrderedDict([
+                    ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]),
+                    ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]),
+                    ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]),
+                    ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]),
+                    ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]),
+                    ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]),
+                    ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0])
+                ])
+
+        for k in blocks.keys():
+            blocks[k] = make_layers(blocks[k], no_relu_layers)
+
+        self.model1_0 = blocks['block1_0']
+        self.model1_1 = blocks['block1_1']
+        self.model2 = blocks['block2']
+        self.model3 = blocks['block3']
+        self.model4 = blocks['block4']
+        self.model5 = blocks['block5']
+        self.model6 = blocks['block6']
+
+    def forward(self, x):
+        out1_0 = self.model1_0(x)
+        out1_1 = self.model1_1(out1_0)
+        concat_stage2 = torch.cat([out1_1, out1_0], 1)
+        out_stage2 = self.model2(concat_stage2)
+        concat_stage3 = torch.cat([out_stage2, out1_0], 1)
+        out_stage3 = self.model3(concat_stage3)
+        concat_stage4 = torch.cat([out_stage3, out1_0], 1)
+        out_stage4 = self.model4(concat_stage4)
+        concat_stage5 = torch.cat([out_stage4, out1_0], 1)
+        out_stage5 = self.model5(concat_stage5)
+        concat_stage6 = torch.cat([out_stage5, out1_0], 1)
+        out_stage6 = self.model6(concat_stage6)
+        return out_stage6
+
+
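+# Editor's sketch (not part of the upstream file): a quick shape check. With a
+# 3 x 368 x 368 input, the three 2x2 max-pools give 46 x 46 maps; bodypose_model
+# returns 38 PAF channels and 19 heatmap channels, handpose_model returns 22.
+if __name__ == "__main__":
+    x = torch.randn(1, 3, 368, 368)
+    with torch.no_grad():
+        pafs, heatmaps = bodypose_model()(x)
+        hand_maps = handpose_model()(x)
+    print(pafs.shape, heatmaps.shape, hand_maps.shape)
+    # torch.Size([1, 38, 46, 46]) torch.Size([1, 19, 46, 46]) torch.Size([1, 22, 46, 46])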
diff --git a/annotator/openpose/util.py b/annotator/openpose/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f91ae0e65abaf0cbd62d803f56498991141e61b
--- /dev/null
+++ b/annotator/openpose/util.py
@@ -0,0 +1,164 @@
+import math
+import numpy as np
+import matplotlib
+import cv2
+
+
+def padRightDownCorner(img, stride, padValue):
+    h = img.shape[0]
+    w = img.shape[1]
+
+    pad = 4 * [None]
+    pad[0] = 0 # up
+    pad[1] = 0 # left
+    pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
+    pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
+
+    img_padded = img
+    pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1))
+    img_padded = np.concatenate((pad_up, img_padded), axis=0)
+    pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1))
+    img_padded = np.concatenate((pad_left, img_padded), axis=1)
+    pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1))
+    img_padded = np.concatenate((img_padded, pad_down), axis=0)
+    pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1))
+    img_padded = np.concatenate((img_padded, pad_right), axis=1)
+
+    return img_padded, pad
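+
+# Editor's note (illustration, assuming stride=8): for a 100 x 150 x 3 image,
+# padRightDownCorner returns a 104 x 152 x 3 image padded with padValue on the
+# bottom/right and pad == [0, 0, 4, 2] (up, left, down, right); body.py/hand.py
+# later crop these pads off the network output.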
+
+# transfer caffe model weights to pytorch, matching the layer names
+def transfer(model, model_weights):
+    transfered_model_weights = {}
+    for weights_name in model.state_dict().keys():
+        transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
+    return transfered_model_weights
+
+# draw the body keypoints and limbs
+def draw_bodypose(canvas, candidate, subset):
+    stickwidth = 4
+    limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
+               [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
+               [1, 16], [16, 18], [3, 17], [6, 18]]
+
+    colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
+              [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
+              [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
+    for i in range(18):
+        for n in range(len(subset)):
+            index = int(subset[n][i])
+            if index == -1:
+                continue
+            x, y = candidate[index][0:2]
+            cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
+    for i in range(17):
+        for n in range(len(subset)):
+            index = subset[n][np.array(limbSeq[i]) - 1]
+            if -1 in index:
+                continue
+            cur_canvas = canvas.copy()
+            Y = candidate[index.astype(int), 0]
+            X = candidate[index.astype(int), 1]
+            mX = np.mean(X)
+            mY = np.mean(Y)
+            length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
+            angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
+            polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
+            cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
+            canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
+    # plt.imsave("preview.jpg", canvas[:, :, [2, 1, 0]])
+    # plt.imshow(canvas[:, :, [2, 1, 0]])
+    return canvas
+
+
+# images drawn with opencv are not great.
+def draw_handpose(canvas, all_hand_peaks, show_number=False):
+    edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
+             [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
+
+    for peaks in all_hand_peaks:
+        for ie, e in enumerate(edges):
+            if np.sum(np.all(peaks[e], axis=1)==0)==0:
+                x1, y1 = peaks[e[0]]
+                x2, y2 = peaks[e[1]]
+                cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie/float(len(edges)), 1.0, 1.0])*255, thickness=2)
+
+        for i, keypoint in enumerate(peaks):
+            x, y = keypoint
+            cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
+            if show_number:
+                cv2.putText(canvas, str(i), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 0, 0), lineType=cv2.LINE_AA)
+    return canvas
+
+# detect hand according to body pose keypoints
+# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
+def handDetect(candidate, subset, oriImg):
+    # right hand: wrist 4, elbow 3, shoulder 2
+    # left hand: wrist 7, elbow 6, shoulder 5
+    ratioWristElbow = 0.33
+    detect_result = []
+    image_height, image_width = oriImg.shape[0:2]
+    for person in subset.astype(int):
+        # if any of three not detected
+        has_left = np.sum(person[[5, 6, 7]] == -1) == 0
+        has_right = np.sum(person[[2, 3, 4]] == -1) == 0
+        if not (has_left or has_right):
+            continue
+        hands = []
+        #left hand
+        if has_left:
+            left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
+            x1, y1 = candidate[left_shoulder_index][:2]
+            x2, y2 = candidate[left_elbow_index][:2]
+            x3, y3 = candidate[left_wrist_index][:2]
+            hands.append([x1, y1, x2, y2, x3, y3, True])
+        # right hand
+        if has_right:
+            right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]]
+            x1, y1 = candidate[right_shoulder_index][:2]
+            x2, y2 = candidate[right_elbow_index][:2]
+            x3, y3 = candidate[right_wrist_index][:2]
+            hands.append([x1, y1, x2, y2, x3, y3, False])
+
+        for x1, y1, x2, y2, x3, y3, is_left in hands:
+            # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
+            # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
+            # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
+            # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
+            # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
+            # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
+            x = x3 + ratioWristElbow * (x3 - x2)
+            y = y3 + ratioWristElbow * (y3 - y2)
+            distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
+            distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
+            width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
+            # x-y refers to the center --> offset to topLeft point
+            # handRectangle.x -= handRectangle.width / 2.f;
+            # handRectangle.y -= handRectangle.height / 2.f;
+            x -= width / 2
+            y -= width / 2  # width = height
+            # overflow the image
+            if x < 0: x = 0
+            if y < 0: y = 0
+            width1 = width
+            width2 = width
+            if x + width > image_width: width1 = image_width - x
+            if y + width > image_height: width2 = image_height - y
+            width = min(width1, width2)
+            # discard hand boxes smaller than 20 pixels
+            if width >= 20:
+                detect_result.append([int(x), int(y), int(width), is_left])
+
+    '''
+    return value: [[x, y, w, True if left hand else False]].
+    width == height since the network requires square input.
+    x, y are the coordinates of the top-left corner.
+    '''
+    return detect_result
+
+# get max index of 2d array
+def npmax(array):
+    arrayindex = array.argmax(1)
+    arrayvalue = array.max(1)
+    i = arrayvalue.argmax()
+    j = arrayindex[i]
+    return i, j
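+
+# Editor's note (illustration): npmax returns the (row, col) position of the global
+# maximum of a 2d array, e.g.
+#   npmax(np.array([[1, 5],
+#                   [9, 2]]))   # -> (1, 0)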
diff --git a/annotator/uniformer/__init__.py b/annotator/uniformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6be429542e4908c2b7648e7ee7c9c5f8253e7c94
--- /dev/null
+++ b/annotator/uniformer/__init__.py
@@ -0,0 +1,23 @@
+import os
+
+from annotator.uniformer.mmseg.apis import init_segmentor, inference_segmentor, show_result_pyplot
+from annotator.uniformer.mmseg.core.evaluation import get_palette
+from annotator.util import annotator_ckpts_path
+
+
+checkpoint_file = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/upernet_global_small.pth"
+
+
+class UniformerDetector:
+    def __init__(self):
+        modelpath = os.path.join(annotator_ckpts_path, "upernet_global_small.pth")
+        if not os.path.exists(modelpath):
+            from basicsr.utils.download_util import load_file_from_url
+            load_file_from_url(checkpoint_file, model_dir=annotator_ckpts_path)
+        config_file = os.path.join(os.path.dirname(annotator_ckpts_path), "uniformer", "exp", "upernet_global_small", "config.py")
+        self.model = init_segmentor(config_file, modelpath).cuda()
+
+    def __call__(self, img):
+        result = inference_segmentor(self.model, img)
+        res_img = show_result_pyplot(self.model, img, result, get_palette('ade'), opacity=1)
+        return res_img
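+
+
+# Editor's sketch (not part of the upstream file): minimal usage, mirroring the other
+# annotators. Assumes a CUDA device (the model is moved to .cuda() above) and a
+# hypothetical input path; the result is an ADE20K-palette segmentation rendering.
+if __name__ == "__main__":
+    import cv2
+    detector = UniformerDetector()
+    img = cv2.imread("test_imgs/room.png")   # hypothetical input image
+    seg = detector(img)                      # HxWx3 color-coded segmentation map
+    cv2.imwrite("seg_preview.png", seg)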
diff --git a/annotator/uniformer/configs/_base_/datasets/ade20k.py b/annotator/uniformer/configs/_base_/datasets/ade20k.py
new file mode 100644
index 0000000000000000000000000000000000000000..efc8b4bb20c981f3db6df7eb52b3dc0744c94cc0
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/datasets/ade20k.py
@@ -0,0 +1,54 @@
+# dataset settings
+dataset_type = 'ADE20KDataset'
+data_root = 'data/ade/ADEChallengeData2016'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (512, 512)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', reduce_zero_label=True),
+    dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
+    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='PhotoMetricDistortion'),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(2048, 512),
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    samples_per_gpu=4,
+    workers_per_gpu=4,
+    train=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/training',
+        ann_dir='annotations/training',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/validation',
+        ann_dir='annotations/validation',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/validation',
+        ann_dir='annotations/validation',
+        pipeline=test_pipeline))
diff --git a/annotator/uniformer/configs/_base_/datasets/chase_db1.py b/annotator/uniformer/configs/_base_/datasets/chase_db1.py
new file mode 100644
index 0000000000000000000000000000000000000000..298594ea925f87f22b37094a2ec50e370aec96a0
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/datasets/chase_db1.py
@@ -0,0 +1,59 @@
+# dataset settings
+dataset_type = 'ChaseDB1Dataset'
+data_root = 'data/CHASE_DB1'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+img_scale = (960, 999)
+crop_size = (128, 128)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations'),
+    dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
+    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='PhotoMetricDistortion'),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_semantic_seg'])
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=img_scale,
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img'])
+        ])
+]
+
+data = dict(
+    samples_per_gpu=4,
+    workers_per_gpu=4,
+    train=dict(
+        type='RepeatDataset',
+        times=40000,
+        dataset=dict(
+            type=dataset_type,
+            data_root=data_root,
+            img_dir='images/training',
+            ann_dir='annotations/training',
+            pipeline=train_pipeline)),
+    val=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/validation',
+        ann_dir='annotations/validation',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/validation',
+        ann_dir='annotations/validation',
+        pipeline=test_pipeline))
diff --git a/annotator/uniformer/configs/_base_/datasets/cityscapes.py b/annotator/uniformer/configs/_base_/datasets/cityscapes.py
new file mode 100644
index 0000000000000000000000000000000000000000..f21867c63e1835f6fceb61f066e802fd8fd2a735
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/datasets/cityscapes.py
@@ -0,0 +1,54 @@
+# dataset settings
+dataset_type = 'CityscapesDataset'
+data_root = 'data/cityscapes/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (512, 1024)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations'),
+    dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
+    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='PhotoMetricDistortion'),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(2048, 1024),
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    samples_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='leftImg8bit/train',
+        ann_dir='gtFine/train',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='leftImg8bit/val',
+        ann_dir='gtFine/val',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='leftImg8bit/val',
+        ann_dir='gtFine/val',
+        pipeline=test_pipeline))
diff --git a/annotator/uniformer/configs/_base_/datasets/cityscapes_769x769.py b/annotator/uniformer/configs/_base_/datasets/cityscapes_769x769.py
new file mode 100644
index 0000000000000000000000000000000000000000..336c7b254fe392b4703039fec86a83acdbd2e1a5
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/datasets/cityscapes_769x769.py
@@ -0,0 +1,35 @@
+_base_ = './cityscapes.py'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (769, 769)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations'),
+    dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)),
+    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='PhotoMetricDistortion'),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(2049, 1025),
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    train=dict(pipeline=train_pipeline),
+    val=dict(pipeline=test_pipeline),
+    test=dict(pipeline=test_pipeline))
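+
+# Editor's note: the _base_ entry above makes this file inherit everything from
+# cityscapes.py, with the dicts defined here overriding the matching keys. A hedged
+# sketch of how such a config is typically resolved (assuming the vendored mmcv
+# package shipped with this repo):
+#   from annotator.uniformer.mmcv import Config
+#   cfg = Config.fromfile('annotator/uniformer/configs/_base_/datasets/cityscapes_769x769.py')
+#   print(cfg.crop_size)   # (769, 769); keys not overridden here come from cityscapes.py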
diff --git a/annotator/uniformer/configs/_base_/datasets/drive.py b/annotator/uniformer/configs/_base_/datasets/drive.py
new file mode 100644
index 0000000000000000000000000000000000000000..06e8ff606e0d2a4514ec8b7d2c6c436a32efcbf4
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/datasets/drive.py
@@ -0,0 +1,59 @@
+# dataset settings
+dataset_type = 'DRIVEDataset'
+data_root = 'data/DRIVE'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+img_scale = (584, 565)
+crop_size = (64, 64)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations'),
+    dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
+    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='PhotoMetricDistortion'),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_semantic_seg'])
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=img_scale,
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img'])
+        ])
+]
+
+data = dict(
+    samples_per_gpu=4,
+    workers_per_gpu=4,
+    train=dict(
+        type='RepeatDataset',
+        times=40000,
+        dataset=dict(
+            type=dataset_type,
+            data_root=data_root,
+            img_dir='images/training',
+            ann_dir='annotations/training',
+            pipeline=train_pipeline)),
+    val=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/validation',
+        ann_dir='annotations/validation',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/validation',
+        ann_dir='annotations/validation',
+        pipeline=test_pipeline))
diff --git a/annotator/uniformer/configs/_base_/datasets/hrf.py b/annotator/uniformer/configs/_base_/datasets/hrf.py
new file mode 100644
index 0000000000000000000000000000000000000000..242d790eb1b83e75cf6b7eaa7a35c674099311ad
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/datasets/hrf.py
@@ -0,0 +1,59 @@
+# dataset settings
+dataset_type = 'HRFDataset'
+data_root = 'data/HRF'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+img_scale = (2336, 3504)
+crop_size = (256, 256)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations'),
+    dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
+    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='PhotoMetricDistortion'),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_semantic_seg'])
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=img_scale,
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img'])
+        ])
+]
+
+data = dict(
+    samples_per_gpu=4,
+    workers_per_gpu=4,
+    train=dict(
+        type='RepeatDataset',
+        times=40000,
+        dataset=dict(
+            type=dataset_type,
+            data_root=data_root,
+            img_dir='images/training',
+            ann_dir='annotations/training',
+            pipeline=train_pipeline)),
+    val=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/validation',
+        ann_dir='annotations/validation',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/validation',
+        ann_dir='annotations/validation',
+        pipeline=test_pipeline))
diff --git a/annotator/uniformer/configs/_base_/datasets/pascal_context.py b/annotator/uniformer/configs/_base_/datasets/pascal_context.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff65bad1b86d7e3a5980bb5b9fc55798dc8df5f4
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/datasets/pascal_context.py
@@ -0,0 +1,60 @@
+# dataset settings
+dataset_type = 'PascalContextDataset'
+data_root = 'data/VOCdevkit/VOC2010/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+img_scale = (520, 520)
+crop_size = (480, 480)
+
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations'),
+    dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
+    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='PhotoMetricDistortion'),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=img_scale,
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    samples_per_gpu=4,
+    workers_per_gpu=4,
+    train=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='JPEGImages',
+        ann_dir='SegmentationClassContext',
+        split='ImageSets/SegmentationContext/train.txt',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='JPEGImages',
+        ann_dir='SegmentationClassContext',
+        split='ImageSets/SegmentationContext/val.txt',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='JPEGImages',
+        ann_dir='SegmentationClassContext',
+        split='ImageSets/SegmentationContext/val.txt',
+        pipeline=test_pipeline))
diff --git a/annotator/uniformer/configs/_base_/datasets/pascal_context_59.py b/annotator/uniformer/configs/_base_/datasets/pascal_context_59.py
new file mode 100644
index 0000000000000000000000000000000000000000..37585abab89834b95cd5bdd993b994fca1db65f6
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/datasets/pascal_context_59.py
@@ -0,0 +1,60 @@
+# dataset settings
+dataset_type = 'PascalContextDataset59'
+data_root = 'data/VOCdevkit/VOC2010/'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+img_scale = (520, 520)
+crop_size = (480, 480)
+
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', reduce_zero_label=True),
+    dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
+    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='PhotoMetricDistortion'),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=img_scale,
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    samples_per_gpu=4,
+    workers_per_gpu=4,
+    train=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='JPEGImages',
+        ann_dir='SegmentationClassContext',
+        split='ImageSets/SegmentationContext/train.txt',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='JPEGImages',
+        ann_dir='SegmentationClassContext',
+        split='ImageSets/SegmentationContext/val.txt',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='JPEGImages',
+        ann_dir='SegmentationClassContext',
+        split='ImageSets/SegmentationContext/val.txt',
+        pipeline=test_pipeline))
diff --git a/annotator/uniformer/configs/_base_/datasets/pascal_voc12.py b/annotator/uniformer/configs/_base_/datasets/pascal_voc12.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba1d42d0c5781f56dc177d860d856bb34adce555
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/datasets/pascal_voc12.py
@@ -0,0 +1,57 @@
+# dataset settings
+dataset_type = 'PascalVOCDataset'
+data_root = 'data/VOCdevkit/VOC2012'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (512, 512)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations'),
+    dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
+    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='PhotoMetricDistortion'),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(2048, 512),
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    samples_per_gpu=4,
+    workers_per_gpu=4,
+    train=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='JPEGImages',
+        ann_dir='SegmentationClass',
+        split='ImageSets/Segmentation/train.txt',
+        pipeline=train_pipeline),
+    val=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='JPEGImages',
+        ann_dir='SegmentationClass',
+        split='ImageSets/Segmentation/val.txt',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='JPEGImages',
+        ann_dir='SegmentationClass',
+        split='ImageSets/Segmentation/val.txt',
+        pipeline=test_pipeline))
diff --git a/annotator/uniformer/configs/_base_/datasets/pascal_voc12_aug.py b/annotator/uniformer/configs/_base_/datasets/pascal_voc12_aug.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f23b6717d53ad29f02dd15046802a2631a5076b
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/datasets/pascal_voc12_aug.py
@@ -0,0 +1,9 @@
+_base_ = './pascal_voc12.py'
+# dataset settings
+data = dict(
+    train=dict(
+        ann_dir=['SegmentationClass', 'SegmentationClassAug'],
+        split=[
+            'ImageSets/Segmentation/train.txt',
+            'ImageSets/Segmentation/aug.txt'
+        ]))
diff --git a/annotator/uniformer/configs/_base_/datasets/stare.py b/annotator/uniformer/configs/_base_/datasets/stare.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f71b25488cc11a6b4d582ac52b5a24e1ad1cf8e
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/datasets/stare.py
@@ -0,0 +1,59 @@
+# dataset settings
+dataset_type = 'STAREDataset'
+data_root = 'data/STARE'
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+img_scale = (605, 700)
+crop_size = (128, 128)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations'),
+    dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
+    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='PhotoMetricDistortion'),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_semantic_seg'])
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=img_scale,
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img'])
+        ])
+]
+
+data = dict(
+    samples_per_gpu=4,
+    workers_per_gpu=4,
+    train=dict(
+        type='RepeatDataset',
+        times=40000,
+        dataset=dict(
+            type=dataset_type,
+            data_root=data_root,
+            img_dir='images/training',
+            ann_dir='annotations/training',
+            pipeline=train_pipeline)),
+    val=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/validation',
+        ann_dir='annotations/validation',
+        pipeline=test_pipeline),
+    test=dict(
+        type=dataset_type,
+        data_root=data_root,
+        img_dir='images/validation',
+        ann_dir='annotations/validation',
+        pipeline=test_pipeline))
diff --git a/annotator/uniformer/configs/_base_/default_runtime.py b/annotator/uniformer/configs/_base_/default_runtime.py
new file mode 100644
index 0000000000000000000000000000000000000000..b564cc4e7e7d9a67dacaaddecb100e4d8f5c005b
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/default_runtime.py
@@ -0,0 +1,14 @@
+# yapf:disable
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type='TextLoggerHook', by_epoch=False),
+        # dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
+cudnn_benchmark = True
diff --git a/annotator/uniformer/configs/_base_/models/ann_r50-d8.py b/annotator/uniformer/configs/_base_/models/ann_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2cb653827e44e6015b3b83bc578003e614a6aa1
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/ann_r50-d8.py
@@ -0,0 +1,46 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='ANNHead',
+        in_channels=[1024, 2048],
+        in_index=[2, 3],
+        channels=512,
+        project_channels=256,
+        query_scales=(1, ),
+        key_pool_scales=(1, 3, 6, 8),
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/apcnet_r50-d8.py b/annotator/uniformer/configs/_base_/models/apcnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8f5316cbcf3896ba9de7ca2c801eba512f01d5e
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/apcnet_r50-d8.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='APCHead',
+        in_channels=2048,
+        in_index=3,
+        channels=512,
+        pool_scales=(1, 2, 3, 6),
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=dict(type='SyncBN', requires_grad=True),
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/ccnet_r50-d8.py b/annotator/uniformer/configs/_base_/models/ccnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..794148f576b9e215c3c6963e73dffe98204b7717
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/ccnet_r50-d8.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='CCHead',
+        in_channels=2048,
+        in_index=3,
+        channels=512,
+        recurrence=2,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/cgnet.py b/annotator/uniformer/configs/_base_/models/cgnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..eff8d9458c877c5db894957e0b1b4597e40da6ab
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/cgnet.py
@@ -0,0 +1,35 @@
+# model settings
+norm_cfg = dict(type='SyncBN', eps=1e-03, requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    backbone=dict(
+        type='CGNet',
+        norm_cfg=norm_cfg,
+        in_channels=3,
+        num_channels=(32, 64, 128),
+        num_blocks=(3, 21),
+        dilations=(2, 4),
+        reductions=(8, 16)),
+    decode_head=dict(
+        type='FCNHead',
+        in_channels=256,
+        in_index=2,
+        channels=256,
+        num_convs=0,
+        concat_input=False,
+        dropout_ratio=0,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        loss_decode=dict(
+            type='CrossEntropyLoss',
+            use_sigmoid=False,
+            loss_weight=1.0,
+            class_weight=[
+                2.5959933, 6.7415504, 3.5354059, 9.8663225, 9.690899, 9.369352,
+                10.289121, 9.953208, 4.3097677, 9.490387, 7.674431, 9.396905,
+                10.347791, 6.3927646, 10.226669, 10.241062, 10.280587,
+                10.396974, 10.055647
+            ])),
+    # model training and testing settings
+    train_cfg=dict(sampler=None),
+    test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/danet_r50-d8.py b/annotator/uniformer/configs/_base_/models/danet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c934939fac48525f22ad86f489a041dd7db7d09
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/danet_r50-d8.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='DAHead',
+        in_channels=2048,
+        in_index=3,
+        channels=512,
+        pam_channels=64,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/deeplabv3_r50-d8.py b/annotator/uniformer/configs/_base_/models/deeplabv3_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7a43bee01422ad4795dd27874e0cd4bb6cbfecf
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/deeplabv3_r50-d8.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='ASPPHead',
+        in_channels=2048,
+        in_index=3,
+        channels=512,
+        dilations=(1, 12, 24, 36),
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/deeplabv3_unet_s5-d16.py b/annotator/uniformer/configs/_base_/models/deeplabv3_unet_s5-d16.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cd262999d8b2cb8e14a5c32190ae73f479d8e81
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/deeplabv3_unet_s5-d16.py
@@ -0,0 +1,50 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained=None,
+    backbone=dict(
+        type='UNet',
+        in_channels=3,
+        base_channels=64,
+        num_stages=5,
+        strides=(1, 1, 1, 1, 1),
+        enc_num_convs=(2, 2, 2, 2, 2),
+        dec_num_convs=(2, 2, 2, 2),
+        downsamples=(True, True, True, True),
+        enc_dilations=(1, 1, 1, 1, 1),
+        dec_dilations=(1, 1, 1, 1),
+        with_cp=False,
+        conv_cfg=None,
+        norm_cfg=norm_cfg,
+        act_cfg=dict(type='ReLU'),
+        upsample_cfg=dict(type='InterpConv'),
+        norm_eval=False),
+    decode_head=dict(
+        type='ASPPHead',
+        in_channels=64,
+        in_index=4,
+        channels=16,
+        dilations=(1, 12, 24, 36),
+        dropout_ratio=0.1,
+        num_classes=2,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=128,
+        in_index=3,
+        channels=64,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=2,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='slide', crop_size=256, stride=170))
diff --git a/annotator/uniformer/configs/_base_/models/deeplabv3plus_r50-d8.py b/annotator/uniformer/configs/_base_/models/deeplabv3plus_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..050e39e091d816df9028d23aa3ecf9db74e441e1
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/deeplabv3plus_r50-d8.py
@@ -0,0 +1,46 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='DepthwiseSeparableASPPHead',
+        in_channels=2048,
+        in_index=3,
+        channels=512,
+        dilations=(1, 12, 24, 36),
+        c1_in_channels=256,
+        c1_channels=48,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/dmnet_r50-d8.py b/annotator/uniformer/configs/_base_/models/dmnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..d22ba52640bebd805b3b8d07025e276dfb023759
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/dmnet_r50-d8.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='DMHead',
+        in_channels=2048,
+        in_index=3,
+        channels=512,
+        filter_sizes=(1, 3, 5, 7),
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=dict(type='SyncBN', requires_grad=True),
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/dnl_r50-d8.py b/annotator/uniformer/configs/_base_/models/dnl_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..edb4c174c51e34c103737ba39bfc48bf831e561d
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/dnl_r50-d8.py
@@ -0,0 +1,46 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='DNLHead',
+        in_channels=2048,
+        in_index=3,
+        channels=512,
+        dropout_ratio=0.1,
+        reduction=2,
+        use_scale=True,
+        mode='embedded_gaussian',
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/emanet_r50-d8.py b/annotator/uniformer/configs/_base_/models/emanet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..26adcd430926de0862204a71d345f2543167f27b
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/emanet_r50-d8.py
@@ -0,0 +1,47 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='EMAHead',
+        in_channels=2048,
+        in_index=3,
+        channels=256,
+        ema_channels=512,
+        num_bases=64,
+        num_stages=3,
+        momentum=0.1,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/encnet_r50-d8.py b/annotator/uniformer/configs/_base_/models/encnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..be777123a886503172a95fe0719e956a147bbd68
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/encnet_r50-d8.py
@@ -0,0 +1,48 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='EncHead',
+        in_channels=[512, 1024, 2048],
+        in_index=(1, 2, 3),
+        channels=512,
+        num_codes=32,
+        use_se_loss=True,
+        add_lateral=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+        loss_se_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.2)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/fast_scnn.py b/annotator/uniformer/configs/_base_/models/fast_scnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..32fdeb659355a5ce5ef2cc7c2f30742703811cdf
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/fast_scnn.py
@@ -0,0 +1,57 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True, momentum=0.01)
+model = dict(
+    type='EncoderDecoder',
+    backbone=dict(
+        type='FastSCNN',
+        downsample_dw_channels=(32, 48),
+        global_in_channels=64,
+        global_block_channels=(64, 96, 128),
+        global_block_strides=(2, 2, 1),
+        global_out_channels=128,
+        higher_in_channels=64,
+        lower_in_channels=128,
+        fusion_out_channels=128,
+        out_indices=(0, 1, 2),
+        norm_cfg=norm_cfg,
+        align_corners=False),
+    decode_head=dict(
+        type='DepthwiseSeparableFCNHead',
+        in_channels=128,
+        channels=128,
+        concat_input=False,
+        num_classes=19,
+        in_index=-1,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
+    auxiliary_head=[
+        dict(
+            type='FCNHead',
+            in_channels=128,
+            channels=32,
+            num_convs=1,
+            num_classes=19,
+            in_index=-2,
+            norm_cfg=norm_cfg,
+            concat_input=False,
+            align_corners=False,
+            loss_decode=dict(
+                type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
+        dict(
+            type='FCNHead',
+            in_channels=64,
+            channels=32,
+            num_convs=1,
+            num_classes=19,
+            in_index=-3,
+            norm_cfg=norm_cfg,
+            concat_input=False,
+            align_corners=False,
+            loss_decode=dict(
+                type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)),
+    ],
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/fcn_hr18.py b/annotator/uniformer/configs/_base_/models/fcn_hr18.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3e299bc89ada56ca14bbffcbdb08a586b8ed9e9
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/fcn_hr18.py
@@ -0,0 +1,52 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://msra/hrnetv2_w18',
+    backbone=dict(
+        type='HRNet',
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        extra=dict(
+            stage1=dict(
+                num_modules=1,
+                num_branches=1,
+                block='BOTTLENECK',
+                num_blocks=(4, ),
+                num_channels=(64, )),
+            stage2=dict(
+                num_modules=1,
+                num_branches=2,
+                block='BASIC',
+                num_blocks=(4, 4),
+                num_channels=(18, 36)),
+            stage3=dict(
+                num_modules=4,
+                num_branches=3,
+                block='BASIC',
+                num_blocks=(4, 4, 4),
+                num_channels=(18, 36, 72)),
+            stage4=dict(
+                num_modules=3,
+                num_branches=4,
+                block='BASIC',
+                num_blocks=(4, 4, 4, 4),
+                num_channels=(18, 36, 72, 144)))),
+    decode_head=dict(
+        type='FCNHead',
+        in_channels=[18, 36, 72, 144],
+        in_index=(0, 1, 2, 3),
+        channels=sum([18, 36, 72, 144]),
+        input_transform='resize_concat',
+        kernel_size=1,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=-1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/fcn_r50-d8.py b/annotator/uniformer/configs/_base_/models/fcn_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e98f6cc918b6146fc6d613c6918e825ef1355c3
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/fcn_r50-d8.py
@@ -0,0 +1,45 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='FCNHead',
+        in_channels=2048,
+        in_index=3,
+        channels=512,
+        num_convs=2,
+        concat_input=True,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/fcn_unet_s5-d16.py b/annotator/uniformer/configs/_base_/models/fcn_unet_s5-d16.py
new file mode 100644
index 0000000000000000000000000000000000000000..a33e7972877f902d0e7d18401ca675e3e4e60a18
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/fcn_unet_s5-d16.py
@@ -0,0 +1,51 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained=None,
+    backbone=dict(
+        type='UNet',
+        in_channels=3,
+        base_channels=64,
+        num_stages=5,
+        strides=(1, 1, 1, 1, 1),
+        enc_num_convs=(2, 2, 2, 2, 2),
+        dec_num_convs=(2, 2, 2, 2),
+        downsamples=(True, True, True, True),
+        enc_dilations=(1, 1, 1, 1, 1),
+        dec_dilations=(1, 1, 1, 1),
+        with_cp=False,
+        conv_cfg=None,
+        norm_cfg=norm_cfg,
+        act_cfg=dict(type='ReLU'),
+        upsample_cfg=dict(type='InterpConv'),
+        norm_eval=False),
+    decode_head=dict(
+        type='FCNHead',
+        in_channels=64,
+        in_index=4,
+        channels=64,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=2,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=128,
+        in_index=3,
+        channels=64,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=2,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='slide', crop_size=256, stride=170))
diff --git a/annotator/uniformer/configs/_base_/models/fpn_r50.py b/annotator/uniformer/configs/_base_/models/fpn_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..86ab327db92e44c14822d65f1c9277cb007f17c1
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/fpn_r50.py
@@ -0,0 +1,36 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 1, 1),
+        strides=(1, 2, 2, 2),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=4),
+    decode_head=dict(
+        type='FPNHead',
+        in_channels=[256, 256, 256, 256],
+        in_index=[0, 1, 2, 3],
+        feature_strides=[4, 8, 16, 32],
+        channels=128,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/fpn_uniformer.py b/annotator/uniformer/configs/_base_/models/fpn_uniformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8aae98c5991055bfcc08e82ccdc09f8b1d9f8a8d
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/fpn_uniformer.py
@@ -0,0 +1,35 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    backbone=dict(
+        type='UniFormer',
+        embed_dim=[64, 128, 320, 512],
+        layers=[3, 4, 8, 3],
+        head_dim=64,
+        mlp_ratio=4.,
+        qkv_bias=True,
+        drop_rate=0.,
+        attn_drop_rate=0.,
+        drop_path_rate=0.1),
+    neck=dict(
+        type='FPN',
+        in_channels=[64, 128, 320, 512],
+        out_channels=256,
+        num_outs=4),
+    decode_head=dict(
+        type='FPNHead',
+        in_channels=[256, 256, 256, 256],
+        in_index=[0, 1, 2, 3],
+        feature_strides=[4, 8, 16, 32],
+        channels=128,
+        dropout_ratio=0.1,
+        num_classes=150,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole')
+)
diff --git a/annotator/uniformer/configs/_base_/models/gcnet_r50-d8.py b/annotator/uniformer/configs/_base_/models/gcnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d2ad69f5c22adfe79d5fdabf920217628987166
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/gcnet_r50-d8.py
@@ -0,0 +1,46 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='GCHead',
+        in_channels=2048,
+        in_index=3,
+        channels=512,
+        ratio=1 / 4.,
+        pooling_type='att',
+        fusion_types=('channel_add', ),
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/lraspp_m-v3-d8.py b/annotator/uniformer/configs/_base_/models/lraspp_m-v3-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..93258242a90695cc94a7c6bd41562d6a75988771
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/lraspp_m-v3-d8.py
@@ -0,0 +1,25 @@
+# model settings
+norm_cfg = dict(type='SyncBN', eps=0.001, requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    backbone=dict(
+        type='MobileNetV3',
+        arch='large',
+        out_indices=(1, 3, 16),
+        norm_cfg=norm_cfg),
+    decode_head=dict(
+        type='LRASPPHead',
+        in_channels=(16, 24, 960),
+        in_index=(0, 1, 2),
+        channels=128,
+        input_transform='multiple_select',
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        act_cfg=dict(type='ReLU'),
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/nonlocal_r50-d8.py b/annotator/uniformer/configs/_base_/models/nonlocal_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..5674a39854cafd1f2e363bac99c58ccae62f24da
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/nonlocal_r50-d8.py
@@ -0,0 +1,46 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='NLHead',
+        in_channels=2048,
+        in_index=3,
+        channels=512,
+        dropout_ratio=0.1,
+        reduction=2,
+        use_scale=True,
+        mode='embedded_gaussian',
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/ocrnet_hr18.py b/annotator/uniformer/configs/_base_/models/ocrnet_hr18.py
new file mode 100644
index 0000000000000000000000000000000000000000..c60f62a7cdf3f5c5096a7a7e725e8268fddcb057
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/ocrnet_hr18.py
@@ -0,0 +1,68 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='CascadeEncoderDecoder',
+    num_stages=2,
+    pretrained='open-mmlab://msra/hrnetv2_w18',
+    backbone=dict(
+        type='HRNet',
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        extra=dict(
+            stage1=dict(
+                num_modules=1,
+                num_branches=1,
+                block='BOTTLENECK',
+                num_blocks=(4, ),
+                num_channels=(64, )),
+            stage2=dict(
+                num_modules=1,
+                num_branches=2,
+                block='BASIC',
+                num_blocks=(4, 4),
+                num_channels=(18, 36)),
+            stage3=dict(
+                num_modules=4,
+                num_branches=3,
+                block='BASIC',
+                num_blocks=(4, 4, 4),
+                num_channels=(18, 36, 72)),
+            stage4=dict(
+                num_modules=3,
+                num_branches=4,
+                block='BASIC',
+                num_blocks=(4, 4, 4, 4),
+                num_channels=(18, 36, 72, 144)))),
+    decode_head=[
+        dict(
+            type='FCNHead',
+            in_channels=[18, 36, 72, 144],
+            channels=sum([18, 36, 72, 144]),
+            in_index=(0, 1, 2, 3),
+            input_transform='resize_concat',
+            kernel_size=1,
+            num_convs=1,
+            concat_input=False,
+            dropout_ratio=-1,
+            num_classes=19,
+            norm_cfg=norm_cfg,
+            align_corners=False,
+            loss_decode=dict(
+                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+        dict(
+            type='OCRHead',
+            in_channels=[18, 36, 72, 144],
+            in_index=(0, 1, 2, 3),
+            input_transform='resize_concat',
+            channels=512,
+            ocr_channels=256,
+            dropout_ratio=-1,
+            num_classes=19,
+            norm_cfg=norm_cfg,
+            align_corners=False,
+            loss_decode=dict(
+                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    ],
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/ocrnet_r50-d8.py b/annotator/uniformer/configs/_base_/models/ocrnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..615aa3ff703942b6c22b2d6e9642504dd3e41ebd
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/ocrnet_r50-d8.py
@@ -0,0 +1,47 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='CascadeEncoderDecoder',
+    num_stages=2,
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=[
+        dict(
+            type='FCNHead',
+            in_channels=1024,
+            in_index=2,
+            channels=256,
+            num_convs=1,
+            concat_input=False,
+            dropout_ratio=0.1,
+            num_classes=19,
+            norm_cfg=norm_cfg,
+            align_corners=False,
+            loss_decode=dict(
+                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+        dict(
+            type='OCRHead',
+            in_channels=2048,
+            in_index=3,
+            channels=512,
+            ocr_channels=256,
+            dropout_ratio=0.1,
+            num_classes=19,
+            norm_cfg=norm_cfg,
+            align_corners=False,
+            loss_decode=dict(
+                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
+    ],
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/pointrend_r50.py b/annotator/uniformer/configs/_base_/models/pointrend_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d323dbf9466d41e0800aa57ef84045f3d874bdf
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/pointrend_r50.py
@@ -0,0 +1,56 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='CascadeEncoderDecoder',
+    num_stages=2,
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 1, 1),
+        strides=(1, 2, 2, 2),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    neck=dict(
+        type='FPN',
+        in_channels=[256, 512, 1024, 2048],
+        out_channels=256,
+        num_outs=4),
+    decode_head=[
+        dict(
+            type='FPNHead',
+            in_channels=[256, 256, 256, 256],
+            in_index=[0, 1, 2, 3],
+            feature_strides=[4, 8, 16, 32],
+            channels=128,
+            dropout_ratio=-1,
+            num_classes=19,
+            norm_cfg=norm_cfg,
+            align_corners=False,
+            loss_decode=dict(
+                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+        dict(
+            type='PointHead',
+            in_channels=[256],
+            in_index=[0],
+            channels=256,
+            num_fcs=3,
+            coarse_pred_each_layer=True,
+            dropout_ratio=-1,
+            num_classes=19,
+            align_corners=False,
+            loss_decode=dict(
+                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
+    ],
+    # model training and testing settings
+    train_cfg=dict(
+        num_points=2048, oversample_ratio=3, importance_sample_ratio=0.75),
+    test_cfg=dict(
+        mode='whole',
+        subdivision_steps=2,
+        subdivision_num_points=8196,
+        scale_factor=2))
diff --git a/annotator/uniformer/configs/_base_/models/psanet_r50-d8.py b/annotator/uniformer/configs/_base_/models/psanet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..689513fa9d2a40f14bf0ae4ae61f38f0dcc1b3da
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/psanet_r50-d8.py
@@ -0,0 +1,49 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='PSAHead',
+        in_channels=2048,
+        in_index=3,
+        channels=512,
+        mask_size=(97, 97),
+        psa_type='bi-direction',
+        compact=False,
+        shrink_factor=2,
+        normalization_factor=1.0,
+        psa_softmax=True,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/pspnet_r50-d8.py b/annotator/uniformer/configs/_base_/models/pspnet_r50-d8.py
new file mode 100644
index 0000000000000000000000000000000000000000..f451e08ad2eb0732dcb806b1851eb978d4acf136
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/pspnet_r50-d8.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='PSPHead',
+        in_channels=2048,
+        in_index=3,
+        channels=512,
+        pool_scales=(1, 2, 3, 6),
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/pspnet_unet_s5-d16.py b/annotator/uniformer/configs/_base_/models/pspnet_unet_s5-d16.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcff9ec4f41fad158344ecd77313dc14564f3682
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/pspnet_unet_s5-d16.py
@@ -0,0 +1,50 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained=None,
+    backbone=dict(
+        type='UNet',
+        in_channels=3,
+        base_channels=64,
+        num_stages=5,
+        strides=(1, 1, 1, 1, 1),
+        enc_num_convs=(2, 2, 2, 2, 2),
+        dec_num_convs=(2, 2, 2, 2),
+        downsamples=(True, True, True, True),
+        enc_dilations=(1, 1, 1, 1, 1),
+        dec_dilations=(1, 1, 1, 1),
+        with_cp=False,
+        conv_cfg=None,
+        norm_cfg=norm_cfg,
+        act_cfg=dict(type='ReLU'),
+        upsample_cfg=dict(type='InterpConv'),
+        norm_eval=False),
+    decode_head=dict(
+        type='PSPHead',
+        in_channels=64,
+        in_index=4,
+        channels=16,
+        pool_scales=(1, 2, 3, 6),
+        dropout_ratio=0.1,
+        num_classes=2,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=128,
+        in_index=3,
+        channels=64,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=2,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='slide', crop_size=256, stride=170))
diff --git a/annotator/uniformer/configs/_base_/models/upernet_r50.py b/annotator/uniformer/configs/_base_/models/upernet_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..10974962fdd7136031fd06de1700f497d355ceaa
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/upernet_r50.py
@@ -0,0 +1,44 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 1, 1),
+        strides=(1, 2, 2, 2),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='UPerHead',
+        in_channels=[256, 512, 1024, 2048],
+        in_index=[0, 1, 2, 3],
+        pool_scales=(1, 2, 3, 6),
+        channels=512,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
diff --git a/annotator/uniformer/configs/_base_/models/upernet_uniformer.py b/annotator/uniformer/configs/_base_/models/upernet_uniformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..41aa4db809dc6e2c508e98051f61807d07477903
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/models/upernet_uniformer.py
@@ -0,0 +1,43 @@
+# model settings
+norm_cfg = dict(type='BN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained=None,
+    backbone=dict(
+        type='UniFormer',
+        embed_dim=[64, 128, 320, 512],
+        layers=[3, 4, 8, 3],
+        head_dim=64,
+        mlp_ratio=4.,
+        qkv_bias=True,
+        drop_rate=0.,
+        attn_drop_rate=0.,
+        drop_path_rate=0.1),
+    decode_head=dict(
+        type='UPerHead',
+        in_channels=[64, 128, 320, 512],
+        in_index=[0, 1, 2, 3],
+        pool_scales=(1, 2, 3, 6),
+        channels=512,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=320,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
\ No newline at end of file
diff --git a/annotator/uniformer/configs/_base_/schedules/schedule_160k.py b/annotator/uniformer/configs/_base_/schedules/schedule_160k.py
new file mode 100644
index 0000000000000000000000000000000000000000..52603890b10f25faf8eec9f9e5a4468fae09b811
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/schedules/schedule_160k.py
@@ -0,0 +1,9 @@
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+optimizer_config = dict()
+# learning policy
+lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
+# runtime settings
+runner = dict(type='IterBasedRunner', max_iters=160000)
+checkpoint_config = dict(by_epoch=False, interval=16000)
+evaluation = dict(interval=16000, metric='mIoU')
diff --git a/annotator/uniformer/configs/_base_/schedules/schedule_20k.py b/annotator/uniformer/configs/_base_/schedules/schedule_20k.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf780a1b6f6521833c6a5859675147824efa599d
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/schedules/schedule_20k.py
@@ -0,0 +1,9 @@
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+optimizer_config = dict()
+# learning policy
+lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
+# runtime settings
+runner = dict(type='IterBasedRunner', max_iters=20000)
+checkpoint_config = dict(by_epoch=False, interval=2000)
+evaluation = dict(interval=2000, metric='mIoU')
diff --git a/annotator/uniformer/configs/_base_/schedules/schedule_40k.py b/annotator/uniformer/configs/_base_/schedules/schedule_40k.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdbf841abcb26eed87bf76ab816aff4bae0630ee
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/schedules/schedule_40k.py
@@ -0,0 +1,9 @@
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+optimizer_config = dict()
+# learning policy
+lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
+# runtime settings
+runner = dict(type='IterBasedRunner', max_iters=40000)
+checkpoint_config = dict(by_epoch=False, interval=4000)
+evaluation = dict(interval=4000, metric='mIoU')
diff --git a/annotator/uniformer/configs/_base_/schedules/schedule_80k.py b/annotator/uniformer/configs/_base_/schedules/schedule_80k.py
new file mode 100644
index 0000000000000000000000000000000000000000..c190cee6bdc7922b688ea75dc8f152fa15c24617
--- /dev/null
+++ b/annotator/uniformer/configs/_base_/schedules/schedule_80k.py
@@ -0,0 +1,9 @@
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+optimizer_config = dict()
+# learning policy
+lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
+# runtime settings
+runner = dict(type='IterBasedRunner', max_iters=80000)
+checkpoint_config = dict(by_epoch=False, interval=8000)
+evaluation = dict(interval=8000, metric='mIoU')
diff --git a/annotator/uniformer/exp/upernet_global_small/config.py b/annotator/uniformer/exp/upernet_global_small/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..01db96bf9b0be531aa0eaf62fee51543712f8670
--- /dev/null
+++ b/annotator/uniformer/exp/upernet_global_small/config.py
@@ -0,0 +1,38 @@
+_base_ = [
+    '../../configs/_base_/models/upernet_uniformer.py', 
+    '../../configs/_base_/datasets/ade20k.py',
+    '../../configs/_base_/default_runtime.py', 
+    '../../configs/_base_/schedules/schedule_160k.py'
+]
+model = dict(
+    backbone=dict(
+        type='UniFormer',
+        embed_dim=[64, 128, 320, 512],
+        layers=[3, 4, 8, 3],
+        head_dim=64,
+        drop_path_rate=0.25,
+        windows=False,
+        hybrid=False
+    ),
+    decode_head=dict(
+        in_channels=[64, 128, 320, 512],
+        num_classes=150
+    ),
+    auxiliary_head=dict(
+        in_channels=320,
+        num_classes=150
+    ))
+
+# AdamW optimizer, no weight decay for position embedding & layer norm in backbone
+optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
+                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
+                                                 'relative_position_bias_table': dict(decay_mult=0.),
+                                                 'norm': dict(decay_mult=0.)}))
+
+lr_config = dict(_delete_=True, policy='poly',
+                 warmup='linear',
+                 warmup_iters=1500,
+                 warmup_ratio=1e-6,
+                 power=1.0, min_lr=0.0, by_epoch=False)
+
+data=dict(samples_per_gpu=2)
\ No newline at end of file
diff --git a/annotator/uniformer/exp/upernet_global_small/run.sh b/annotator/uniformer/exp/upernet_global_small/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9fb22edfa7a32624ea08a63fe7d720c40db3b696
--- /dev/null
+++ b/annotator/uniformer/exp/upernet_global_small/run.sh
@@ -0,0 +1,10 @@
+#!/usr/bin/env bash
+
+work_path=$(dirname $0)
+PYTHONPATH="$(dirname $0)/../../":$PYTHONPATH \
+python -m torch.distributed.launch --nproc_per_node=8 \
+    tools/train.py ${work_path}/config.py \
+    --launcher pytorch \
+    --options model.backbone.pretrained_path='your_model_path/uniformer_small_in1k.pth' \
+    --work-dir ${work_path}/ckpt \
+    2>&1 | tee -a ${work_path}/log.txt
diff --git a/annotator/uniformer/exp/upernet_global_small/test.sh b/annotator/uniformer/exp/upernet_global_small/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d9a85e7a0d3b7c96b060f473d41254b37a382fcb
--- /dev/null
+++ b/annotator/uniformer/exp/upernet_global_small/test.sh
@@ -0,0 +1,10 @@
+#!/usr/bin/env bash
+
+work_path=$(dirname $0)
+PYTHONPATH="$(dirname $0)/../../":$PYTHONPATH \
+python -m torch.distributed.launch --nproc_per_node=8 \
+    tools/test.py ${work_path}/test_config_h32.py \
+    ${work_path}/ckpt/latest.pth \
+    --launcher pytorch \
+    --eval mIoU \
+    2>&1 | tee -a ${work_path}/log.txt
diff --git a/annotator/uniformer/exp/upernet_global_small/test_config_g.py b/annotator/uniformer/exp/upernet_global_small/test_config_g.py
new file mode 100644
index 0000000000000000000000000000000000000000..e43737a98a3b174a9f2fe059c06d511144686459
--- /dev/null
+++ b/annotator/uniformer/exp/upernet_global_small/test_config_g.py
@@ -0,0 +1,38 @@
+_base_ = [
+    '../../configs/_base_/models/upernet_uniformer.py', 
+    '../../configs/_base_/datasets/ade20k.py',
+    '../../configs/_base_/default_runtime.py', 
+    '../../configs/_base_/schedules/schedule_160k.py'
+]
+model = dict(
+    backbone=dict(
+        type='UniFormer',
+        embed_dim=[64, 128, 320, 512],
+        layers=[3, 4, 8, 3],
+        head_dim=64,
+        drop_path_rate=0.25,
+        windows=False,
+        hybrid=False,
+    ),
+    decode_head=dict(
+        in_channels=[64, 128, 320, 512],
+        num_classes=150
+    ),
+    auxiliary_head=dict(
+        in_channels=320,
+        num_classes=150
+    ))
+
+# AdamW optimizer, no weight decay for position embedding & layer norm in backbone
+optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
+                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
+                                                 'relative_position_bias_table': dict(decay_mult=0.),
+                                                 'norm': dict(decay_mult=0.)}))
+
+lr_config = dict(_delete_=True, policy='poly',
+                 warmup='linear',
+                 warmup_iters=1500,
+                 warmup_ratio=1e-6,
+                 power=1.0, min_lr=0.0, by_epoch=False)
+
+data=dict(samples_per_gpu=2)
\ No newline at end of file
diff --git a/annotator/uniformer/exp/upernet_global_small/test_config_h32.py b/annotator/uniformer/exp/upernet_global_small/test_config_h32.py
new file mode 100644
index 0000000000000000000000000000000000000000..a31e3874f76f9f7b089ac8834d85df2441af9b0e
--- /dev/null
+++ b/annotator/uniformer/exp/upernet_global_small/test_config_h32.py
@@ -0,0 +1,39 @@
+_base_ = [
+    '../../configs/_base_/models/upernet_uniformer.py', 
+    '../../configs/_base_/datasets/ade20k.py',
+    '../../configs/_base_/default_runtime.py', 
+    '../../configs/_base_/schedules/schedule_160k.py'
+]
+model = dict(
+    backbone=dict(
+        type='UniFormer',
+        embed_dim=[64, 128, 320, 512],
+        layers=[3, 4, 8, 3],
+        head_dim=64,
+        drop_path_rate=0.25,
+        windows=False,
+        hybrid=True,
+        window_size=32
+    ),
+    decode_head=dict(
+        in_channels=[64, 128, 320, 512],
+        num_classes=150
+    ),
+    auxiliary_head=dict(
+        in_channels=320,
+        num_classes=150
+    ))
+
+# AdamW optimizer, no weight decay for position embedding & layer norm in backbone
+optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
+                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
+                                                 'relative_position_bias_table': dict(decay_mult=0.),
+                                                 'norm': dict(decay_mult=0.)}))
+
+lr_config = dict(_delete_=True, policy='poly',
+                 warmup='linear',
+                 warmup_iters=1500,
+                 warmup_ratio=1e-6,
+                 power=1.0, min_lr=0.0, by_epoch=False)
+
+data=dict(samples_per_gpu=2)
\ No newline at end of file
diff --git a/annotator/uniformer/exp/upernet_global_small/test_config_w32.py b/annotator/uniformer/exp/upernet_global_small/test_config_w32.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d9e06f029e46c14cb9ddb39319cabe86fef9b44
--- /dev/null
+++ b/annotator/uniformer/exp/upernet_global_small/test_config_w32.py
@@ -0,0 +1,39 @@
+_base_ = [
+    '../../configs/_base_/models/upernet_uniformer.py', 
+    '../../configs/_base_/datasets/ade20k.py',
+    '../../configs/_base_/default_runtime.py', 
+    '../../configs/_base_/schedules/schedule_160k.py'
+]
+model = dict(
+    backbone=dict(
+        type='UniFormer',
+        embed_dim=[64, 128, 320, 512],
+        layers=[3, 4, 8, 3],
+        head_dim=64,
+        drop_path_rate=0.25,
+        windows=True,
+        hybrid=False,
+        window_size=32
+    ),
+    decode_head=dict(
+        in_channels=[64, 128, 320, 512],
+        num_classes=150
+    ),
+    auxiliary_head=dict(
+        in_channels=320,
+        num_classes=150
+    ))
+
+# AdamW optimizer, no weight decay for position embedding & layer norm in backbone
+optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
+                 paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
+                                                 'relative_position_bias_table': dict(decay_mult=0.),
+                                                 'norm': dict(decay_mult=0.)}))
+
+lr_config = dict(_delete_=True, policy='poly',
+                 warmup='linear',
+                 warmup_iters=1500,
+                 warmup_ratio=1e-6,
+                 power=1.0, min_lr=0.0, by_epoch=False)
+
+data=dict(samples_per_gpu=2)
\ No newline at end of file
diff --git a/annotator/uniformer/mmcv/__init__.py b/annotator/uniformer/mmcv/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..210a2989138380559f23045b568d0fbbeb918c03
--- /dev/null
+++ b/annotator/uniformer/mmcv/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# flake8: noqa
+from .arraymisc import *
+from .fileio import *
+from .image import *
+from .utils import *
+from .version import *
+from .video import *
+from .visualization import *
+
+# The following modules are not imported to this level, so mmcv may be used
+# without PyTorch.
+# - runner
+# - parallel
+# - op
diff --git a/annotator/uniformer/mmcv/arraymisc/__init__.py b/annotator/uniformer/mmcv/arraymisc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b4700d6139ae3d604ff6e542468cce4200c020c
--- /dev/null
+++ b/annotator/uniformer/mmcv/arraymisc/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .quantization import dequantize, quantize
+
+__all__ = ['quantize', 'dequantize']
diff --git a/annotator/uniformer/mmcv/arraymisc/quantization.py b/annotator/uniformer/mmcv/arraymisc/quantization.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e47a3545780cf071a1ef8195efb0b7b662c8186
--- /dev/null
+++ b/annotator/uniformer/mmcv/arraymisc/quantization.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+
+
+def quantize(arr, min_val, max_val, levels, dtype=np.int64):
+    """Quantize an array of (-inf, inf) to [0, levels-1].
+
+    Args:
+        arr (ndarray): Input array.
+        min_val (scalar): Minimum value to be clipped.
+        max_val (scalar): Maximum value to be clipped.
+        levels (int): Quantization levels.
+        dtype (np.type): The type of the quantized array.
+
+    Returns:
+        ndarray: Quantized array.
+    """
+    if not (isinstance(levels, int) and levels > 1):
+        raise ValueError(
+            f'levels must be an integer greater than 1, but got {levels}')
+    if min_val >= max_val:
+        raise ValueError(
+            f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+    arr = np.clip(arr, min_val, max_val) - min_val
+    quantized_arr = np.minimum(
+        np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
+
+    return quantized_arr
+
+
+def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
+    """Dequantize an array.
+
+    Args:
+        arr (ndarray): Input array.
+        min_val (scalar): Minimum value to be clipped.
+        max_val (scalar): Maximum value to be clipped.
+        levels (int): Quantization levels.
+        dtype (np.type): The type of the dequantized array.
+
+    Returns:
+        ndarray: Dequantized array.
+    """
+    if not (isinstance(levels, int) and levels > 1):
+        raise ValueError(
+            f'levels must be an integer greater than 1, but got {levels}')
+    if min_val >= max_val:
+        raise ValueError(
+            f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+    dequantized_arr = (arr + 0.5).astype(dtype) * (max_val -
+                                                   min_val) / levels + min_val
+
+    return dequantized_arr
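+
+# Hypothetical usage sketch (not part of the upstream file): quantize a float
+# array into 256 bins and map it back; the round-trip error is bounded by the
+# bin width (max_val - min_val) / levels.
+if __name__ == '__main__':
+    data = np.linspace(-1.5, 1.5, num=7)
+    q = quantize(data, min_val=-1.0, max_val=1.0, levels=256)
+    restored = dequantize(q, min_val=-1.0, max_val=1.0, levels=256)
+    print(q)
+    print(np.abs(restored - np.clip(data, -1.0, 1.0)).max())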
diff --git a/annotator/uniformer/mmcv/cnn/__init__.py b/annotator/uniformer/mmcv/cnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7246c897430f0cc7ce12719ad8608824fc734446
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/__init__.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .alexnet import AlexNet
+# yapf: disable
+from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
+                     PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS,
+                     ContextBlock, Conv2d, Conv3d, ConvAWS2d, ConvModule,
+                     ConvTranspose2d, ConvTranspose3d, ConvWS2d,
+                     DepthwiseSeparableConvModule, GeneralizedAttention,
+                     HSigmoid, HSwish, Linear, MaxPool2d, MaxPool3d,
+                     NonLocal1d, NonLocal2d, NonLocal3d, Scale, Swish,
+                     build_activation_layer, build_conv_layer,
+                     build_norm_layer, build_padding_layer, build_plugin_layer,
+                     build_upsample_layer, conv_ws_2d, is_norm)
+from .builder import MODELS, build_model_from_cfg
+# yapf: enable
+from .resnet import ResNet, make_res_layer
+from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
+                    NormalInit, PretrainedInit, TruncNormalInit, UniformInit,
+                    XavierInit, bias_init_with_prob, caffe2_xavier_init,
+                    constant_init, fuse_conv_bn, get_model_complexity_info,
+                    initialize, kaiming_init, normal_init, trunc_normal_init,
+                    uniform_init, xavier_init)
+from .vgg import VGG, make_vgg_layer
+
+__all__ = [
+    'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
+    'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init',
+    'uniform_init', 'kaiming_init', 'caffe2_xavier_init',
+    'bias_init_with_prob', 'ConvModule', 'build_activation_layer',
+    'build_conv_layer', 'build_norm_layer', 'build_padding_layer',
+    'build_upsample_layer', 'build_plugin_layer', 'is_norm', 'NonLocal1d',
+    'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'HSigmoid', 'Swish', 'HSwish',
+    'GeneralizedAttention', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS',
+    'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale',
+    'get_model_complexity_info', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d',
+    'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'Linear', 'Conv2d',
+    'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d',
+    'initialize', 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+    'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+    'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg'
+]
diff --git a/annotator/uniformer/mmcv/cnn/alexnet.py b/annotator/uniformer/mmcv/cnn/alexnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..89e36b8c7851f895d9ae7f07149f0e707456aab0
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/alexnet.py
@@ -0,0 +1,61 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+
+import torch.nn as nn
+
+
+class AlexNet(nn.Module):
+    """AlexNet backbone.
+
+    Args:
+        num_classes (int): number of classes for classification.
+    """
+
+    def __init__(self, num_classes=-1):
+        super(AlexNet, self).__init__()
+        self.num_classes = num_classes
+        self.features = nn.Sequential(
+            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+            nn.Conv2d(64, 192, kernel_size=5, padding=2),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+            nn.Conv2d(192, 384, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(384, 256, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(256, 256, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=3, stride=2),
+        )
+        if self.num_classes > 0:
+            self.classifier = nn.Sequential(
+                nn.Dropout(),
+                nn.Linear(256 * 6 * 6, 4096),
+                nn.ReLU(inplace=True),
+                nn.Dropout(),
+                nn.Linear(4096, 4096),
+                nn.ReLU(inplace=True),
+                nn.Linear(4096, num_classes),
+            )
+
+    def init_weights(self, pretrained=None):
+        if isinstance(pretrained, str):
+            logger = logging.getLogger()
+            from ..runner import load_checkpoint
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            # use default initializer
+            pass
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    def forward(self, x):
+
+        x = self.features(x)
+        if self.num_classes > 0:
+            x = x.view(x.size(0), 256 * 6 * 6)
+            x = self.classifier(x)
+
+        return x
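+
+# Hypothetical usage sketch (not part of the upstream file): run a dummy batch
+# through the backbone; with num_classes > 0 the classifier head is attached.
+if __name__ == '__main__':
+    import torch
+    model = AlexNet(num_classes=10)
+    model.init_weights()
+    out = model(torch.randn(1, 3, 224, 224))
+    print(out.shape)  # expected: torch.Size([1, 10])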
diff --git a/annotator/uniformer/mmcv/cnn/bricks/__init__.py b/annotator/uniformer/mmcv/cnn/bricks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f33124ed23fc6f27119a37bcb5ab004d3572be0
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/__init__.py
@@ -0,0 +1,35 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .activation import build_activation_layer
+from .context_block import ContextBlock
+from .conv import build_conv_layer
+from .conv2d_adaptive_padding import Conv2dAdaptivePadding
+from .conv_module import ConvModule
+from .conv_ws import ConvAWS2d, ConvWS2d, conv_ws_2d
+from .depthwise_separable_conv_module import DepthwiseSeparableConvModule
+from .drop import Dropout, DropPath
+from .generalized_attention import GeneralizedAttention
+from .hsigmoid import HSigmoid
+from .hswish import HSwish
+from .non_local import NonLocal1d, NonLocal2d, NonLocal3d
+from .norm import build_norm_layer, is_norm
+from .padding import build_padding_layer
+from .plugin import build_plugin_layer
+from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
+                       PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS)
+from .scale import Scale
+from .swish import Swish
+from .upsample import build_upsample_layer
+from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
+                       Linear, MaxPool2d, MaxPool3d)
+
+__all__ = [
+    'ConvModule', 'build_activation_layer', 'build_conv_layer',
+    'build_norm_layer', 'build_padding_layer', 'build_upsample_layer',
+    'build_plugin_layer', 'is_norm', 'HSigmoid', 'HSwish', 'NonLocal1d',
+    'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'GeneralizedAttention',
+    'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
+    'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
+    'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
+    'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d',
+    'ConvTranspose3d', 'MaxPool3d', 'Conv3d', 'Dropout', 'DropPath'
+]
diff --git a/annotator/uniformer/mmcv/cnn/bricks/activation.py b/annotator/uniformer/mmcv/cnn/bricks/activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..cab2712287d5ef7be2f079dcb54a94b96394eab5
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/activation.py
@@ -0,0 +1,92 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from annotator.uniformer.mmcv.utils import TORCH_VERSION, build_from_cfg, digit_version
+from .registry import ACTIVATION_LAYERS
+
+for module in [
+        nn.ReLU, nn.LeakyReLU, nn.PReLU, nn.RReLU, nn.ReLU6, nn.ELU,
+        nn.Sigmoid, nn.Tanh
+]:
+    ACTIVATION_LAYERS.register_module(module=module)
+
+
+@ACTIVATION_LAYERS.register_module(name='Clip')
+@ACTIVATION_LAYERS.register_module()
+class Clamp(nn.Module):
+    """Clamp activation layer.
+
+    This activation function clamps the feature map values to the range
+    :math:`[min, max]`. More details can be found in ``torch.clamp()``.
+
+    Args:
+        min (Number | optional): Lower-bound of the range to be clamped to.
+            Defaults to -1.
+        max (Number | optional): Upper-bound of the range to be clamped to.
+            Defaults to 1.
+    """
+
+    def __init__(self, min=-1., max=1.):
+        super(Clamp, self).__init__()
+        self.min = min
+        self.max = max
+
+    def forward(self, x):
+        """Forward function.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: Clamped tensor.
+        """
+        return torch.clamp(x, min=self.min, max=self.max)
+
+
+class GELU(nn.Module):
+    r"""Applies the Gaussian Error Linear Units function:
+
+    .. math::
+        \text{GELU}(x) = x * \Phi(x)
+    where :math:`\Phi(x)` is the Cumulative Distribution Function for
+    Gaussian Distribution.
+
+    Shape:
+        - Input: :math:`(N, *)` where `*` means, any number of additional
+          dimensions
+        - Output: :math:`(N, *)`, same shape as the input
+
+    .. image:: scripts/activation_images/GELU.png
+
+    Examples::
+
+        >>> m = nn.GELU()
+        >>> input = torch.randn(2)
+        >>> output = m(input)
+    """
+
+    def forward(self, input):
+        return F.gelu(input)
+
+
+if (TORCH_VERSION == 'parrots'
+        or digit_version(TORCH_VERSION) < digit_version('1.4')):
+    ACTIVATION_LAYERS.register_module(module=GELU)
+else:
+    ACTIVATION_LAYERS.register_module(module=nn.GELU)
+
+
+def build_activation_layer(cfg):
+    """Build activation layer.
+
+    Args:
+        cfg (dict): The activation layer config, which should contain:
+            - type (str): Layer type.
+            - layer args: Args needed to instantiate an activation layer.
+
+    Returns:
+        nn.Module: Created activation layer.
+    """
+    return build_from_cfg(cfg, ACTIVATION_LAYERS)
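+
+
+# Hypothetical usage sketch (not part of the upstream file): activation layers
+# registered above can be built from plain config dicts.
+if __name__ == '__main__':
+    relu = build_activation_layer(dict(type='ReLU', inplace=True))
+    clamp = build_activation_layer(dict(type='Clamp', min=0., max=6.))
+    print(relu, clamp)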
diff --git a/annotator/uniformer/mmcv/cnn/bricks/context_block.py b/annotator/uniformer/mmcv/cnn/bricks/context_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..d60fdb904c749ce3b251510dff3cc63cea70d42e
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/context_block.py
@@ -0,0 +1,125 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn
+
+from ..utils import constant_init, kaiming_init
+from .registry import PLUGIN_LAYERS
+
+
+def last_zero_init(m):
+    if isinstance(m, nn.Sequential):
+        constant_init(m[-1], val=0)
+    else:
+        constant_init(m, val=0)
+
+
+@PLUGIN_LAYERS.register_module()
+class ContextBlock(nn.Module):
+    """ContextBlock module in GCNet.
+
+    See 'GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond'
+    (https://arxiv.org/abs/1904.11492) for details.
+
+    Args:
+        in_channels (int): Channels of the input feature map.
+        ratio (float): Ratio of channels of the transform bottleneck.
+        pooling_type (str): Pooling method for context modeling.
+            Options are 'att' and 'avg', which stand for attention pooling and
+            average pooling respectively. Default: 'att'.
+        fusion_types (Sequence[str]): Fusion methods for feature fusion.
+            Options are 'channel_add' and 'channel_mul', which stand for
+            channelwise addition and multiplication respectively.
+            Default: ('channel_add',).
+    """
+
+    _abbr_ = 'context_block'
+
+    def __init__(self,
+                 in_channels,
+                 ratio,
+                 pooling_type='att',
+                 fusion_types=('channel_add', )):
+        super(ContextBlock, self).__init__()
+        assert pooling_type in ['avg', 'att']
+        assert isinstance(fusion_types, (list, tuple))
+        valid_fusion_types = ['channel_add', 'channel_mul']
+        assert all([f in valid_fusion_types for f in fusion_types])
+        assert len(fusion_types) > 0, 'at least one fusion should be used'
+        self.in_channels = in_channels
+        self.ratio = ratio
+        self.planes = int(in_channels * ratio)
+        self.pooling_type = pooling_type
+        self.fusion_types = fusion_types
+        if pooling_type == 'att':
+            self.conv_mask = nn.Conv2d(in_channels, 1, kernel_size=1)
+            self.softmax = nn.Softmax(dim=2)
+        else:
+            self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        if 'channel_add' in fusion_types:
+            self.channel_add_conv = nn.Sequential(
+                nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
+                nn.LayerNorm([self.planes, 1, 1]),
+                nn.ReLU(inplace=True),  # yapf: disable
+                nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
+        else:
+            self.channel_add_conv = None
+        if 'channel_mul' in fusion_types:
+            self.channel_mul_conv = nn.Sequential(
+                nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
+                nn.LayerNorm([self.planes, 1, 1]),
+                nn.ReLU(inplace=True),  # yapf: disable
+                nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
+        else:
+            self.channel_mul_conv = None
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        if self.pooling_type == 'att':
+            kaiming_init(self.conv_mask, mode='fan_in')
+            self.conv_mask.inited = True
+
+        if self.channel_add_conv is not None:
+            last_zero_init(self.channel_add_conv)
+        if self.channel_mul_conv is not None:
+            last_zero_init(self.channel_mul_conv)
+
+    def spatial_pool(self, x):
+        batch, channel, height, width = x.size()
+        if self.pooling_type == 'att':
+            input_x = x
+            # [N, C, H * W]
+            input_x = input_x.view(batch, channel, height * width)
+            # [N, 1, C, H * W]
+            input_x = input_x.unsqueeze(1)
+            # [N, 1, H, W]
+            context_mask = self.conv_mask(x)
+            # [N, 1, H * W]
+            context_mask = context_mask.view(batch, 1, height * width)
+            # [N, 1, H * W]
+            context_mask = self.softmax(context_mask)
+            # [N, 1, H * W, 1]
+            context_mask = context_mask.unsqueeze(-1)
+            # [N, 1, C, 1]
+            context = torch.matmul(input_x, context_mask)
+            # [N, C, 1, 1]
+            context = context.view(batch, channel, 1, 1)
+        else:
+            # [N, C, 1, 1]
+            context = self.avg_pool(x)
+
+        return context
+
+    def forward(self, x):
+        # [N, C, 1, 1]
+        context = self.spatial_pool(x)
+
+        out = x
+        if self.channel_mul_conv is not None:
+            # [N, C, 1, 1]
+            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
+            out = out * channel_mul_term
+        if self.channel_add_conv is not None:
+            # [N, C, 1, 1]
+            channel_add_term = self.channel_add_conv(context)
+            out = out + channel_add_term
+
+        return out
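+
+# Hypothetical usage sketch (not part of the upstream file): ContextBlock is
+# shape-preserving, so it can be inserted after any feature map.
+if __name__ == '__main__':
+    block = ContextBlock(in_channels=64, ratio=1. / 4)
+    feat = torch.rand(2, 64, 32, 32)
+    print(block(feat).shape)  # expected: torch.Size([2, 64, 32, 32])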
diff --git a/annotator/uniformer/mmcv/cnn/bricks/conv.py b/annotator/uniformer/mmcv/cnn/bricks/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf54491997a48ac3e7fadc4183ab7bf3e831024c
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/conv.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch import nn
+
+from .registry import CONV_LAYERS
+
+CONV_LAYERS.register_module('Conv1d', module=nn.Conv1d)
+CONV_LAYERS.register_module('Conv2d', module=nn.Conv2d)
+CONV_LAYERS.register_module('Conv3d', module=nn.Conv3d)
+CONV_LAYERS.register_module('Conv', module=nn.Conv2d)
+
+
+def build_conv_layer(cfg, *args, **kwargs):
+    """Build convolution layer.
+
+    Args:
+        cfg (None or dict): The conv layer config, which should contain:
+            - type (str): Layer type.
+            - layer args: Args needed to instantiate a conv layer.
+        args (argument list): Arguments passed to the `__init__`
+            method of the corresponding conv layer.
+        kwargs (keyword arguments): Keyword arguments passed to the `__init__`
+            method of the corresponding conv layer.
+
+    Returns:
+        nn.Module: Created conv layer.
+    """
+    if cfg is None:
+        cfg_ = dict(type='Conv2d')
+    else:
+        if not isinstance(cfg, dict):
+            raise TypeError('cfg must be a dict')
+        if 'type' not in cfg:
+            raise KeyError('the cfg dict must contain the key "type"')
+        cfg_ = cfg.copy()
+
+    layer_type = cfg_.pop('type')
+    if layer_type not in CONV_LAYERS:
+        raise KeyError(f'Unrecognized conv type {layer_type}')
+    else:
+        conv_layer = CONV_LAYERS.get(layer_type)
+
+    layer = conv_layer(*args, **kwargs, **cfg_)
+
+    return layer
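+
+# Hypothetical usage sketch (not part of the upstream file): cfg=None falls
+# back to a plain Conv2d, while a config dict selects a registered conv type.
+if __name__ == '__main__':
+    default_conv = build_conv_layer(None, 3, 16, 3, padding=1)
+    conv1d = build_conv_layer(dict(type='Conv1d'), 3, 16, 3, padding=1)
+    print(type(default_conv).__name__, type(conv1d).__name__)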
diff --git a/annotator/uniformer/mmcv/cnn/bricks/conv2d_adaptive_padding.py b/annotator/uniformer/mmcv/cnn/bricks/conv2d_adaptive_padding.py
new file mode 100644
index 0000000000000000000000000000000000000000..b45e758ac6cf8dfb0382d072fe09125bc7e9b888
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/conv2d_adaptive_padding.py
@@ -0,0 +1,62 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+from torch import nn
+from torch.nn import functional as F
+
+from .registry import CONV_LAYERS
+
+
+@CONV_LAYERS.register_module()
+class Conv2dAdaptivePadding(nn.Conv2d):
+    """Implementation of 2D convolution in tensorflow with `padding` as "same",
+    which applies padding to input (if needed) so that input image gets fully
+    covered by filter and stride you specified. For stride 1, this will ensure
+    that output image size is same as input. For stride of 2, output dimensions
+    will be half, for example.
+
+    Args:
+        in_channels (int): Number of channels in the input image
+        out_channels (int): Number of channels produced by the convolution
+        kernel_size (int or tuple): Size of the convolving kernel
+        stride (int or tuple, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): Zero-padding added to both sides of
+            the input. Default: 0
+        dilation (int or tuple, optional): Spacing between kernel elements.
+            Default: 1
+        groups (int, optional): Number of blocked connections from input
+            channels to output channels. Default: 1
+        bias (bool, optional): If ``True``, adds a learnable bias to the
+            output. Default: ``True``
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 bias=True):
+        super().__init__(in_channels, out_channels, kernel_size, stride, 0,
+                         dilation, groups, bias)
+
+    def forward(self, x):
+        img_h, img_w = x.size()[-2:]
+        kernel_h, kernel_w = self.weight.size()[-2:]
+        stride_h, stride_w = self.stride
+        output_h = math.ceil(img_h / stride_h)
+        output_w = math.ceil(img_w / stride_w)
+        pad_h = (
+            max((output_h - 1) * self.stride[0] +
+                (kernel_h - 1) * self.dilation[0] + 1 - img_h, 0))
+        pad_w = (
+            max((output_w - 1) * self.stride[1] +
+                (kernel_w - 1) * self.dilation[1] + 1 - img_w, 0))
+        if pad_h > 0 or pad_w > 0:
+            x = F.pad(x, [
+                pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
+            ])
+        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
+                        self.dilation, self.groups)
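+
+# Hypothetical usage sketch (not part of the upstream file): with stride 2 the
+# output spatial size is ceil(input_size / 2) for any input size.
+if __name__ == '__main__':
+    import torch
+    conv = Conv2dAdaptivePadding(3, 8, kernel_size=3, stride=2)
+    print(conv(torch.rand(1, 3, 7, 7)).shape)  # expected: torch.Size([1, 8, 4, 4])
+    print(conv(torch.rand(1, 3, 8, 8)).shape)  # expected: torch.Size([1, 8, 4, 4])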
diff --git a/annotator/uniformer/mmcv/cnn/bricks/conv_module.py b/annotator/uniformer/mmcv/cnn/bricks/conv_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..e60e7e62245071c77b652093fddebff3948d7c3e
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/conv_module.py
@@ -0,0 +1,206 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch.nn as nn
+
+from annotator.uniformer.mmcv.utils import _BatchNorm, _InstanceNorm
+from ..utils import constant_init, kaiming_init
+from .activation import build_activation_layer
+from .conv import build_conv_layer
+from .norm import build_norm_layer
+from .padding import build_padding_layer
+from .registry import PLUGIN_LAYERS
+
+
+@PLUGIN_LAYERS.register_module()
+class ConvModule(nn.Module):
+    """A conv block that bundles conv/norm/activation layers.
+
+    This block simplifies the usage of convolution layers, which are commonly
+    used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
+    It is based upon three build methods: `build_conv_layer()`,
+    `build_norm_layer()` and `build_activation_layer()`.
+
+    In addition, this module provides some extra features:
+    1. Automatically set `bias` of the conv layer.
+    2. Spectral norm is supported.
+    3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
+    supports zero and circular padding, and we add "reflect" padding mode.
+
+    Args:
+        in_channels (int): Number of channels in the input feature map.
+            Same as that in ``nn._ConvNd``.
+        out_channels (int): Number of channels produced by the convolution.
+            Same as that in ``nn._ConvNd``.
+        kernel_size (int | tuple[int]): Size of the convolving kernel.
+            Same as that in ``nn._ConvNd``.
+        stride (int | tuple[int]): Stride of the convolution.
+            Same as that in ``nn._ConvNd``.
+        padding (int | tuple[int]): Zero-padding added to both sides of
+            the input. Same as that in ``nn._ConvNd``.
+        dilation (int | tuple[int]): Spacing between kernel elements.
+            Same as that in ``nn._ConvNd``.
+        groups (int): Number of blocked connections from input channels to
+            output channels. Same as that in ``nn._ConvNd``.
+        bias (bool | str): If specified as `auto`, it will be decided by the
+            norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise
+            False. Default: "auto".
+        conv_cfg (dict): Config dict for convolution layer. Default: None,
+            which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer. Default: None.
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='ReLU').
+        inplace (bool): Whether to use inplace mode for activation.
+            Default: True.
+        with_spectral_norm (bool): Whether to use spectral norm in the conv
+            module. Default: False.
+        padding_mode (str): If the `padding_mode` has not been supported by
+            current `Conv2d` in PyTorch, we will use our own padding layer
+            instead. Currently, we support ['zeros', 'circular'] with official
+            implementation and ['reflect'] with our own implementation.
+            Default: 'zeros'.
+        order (tuple[str]): The order of conv/norm/activation layers. It is a
+            sequence of "conv", "norm" and "act". Common examples are
+            ("conv", "norm", "act") and ("act", "conv", "norm").
+            Default: ('conv', 'norm', 'act').
+    """
+
+    _abbr_ = 'conv_block'
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 bias='auto',
+                 conv_cfg=None,
+                 norm_cfg=None,
+                 act_cfg=dict(type='ReLU'),
+                 inplace=True,
+                 with_spectral_norm=False,
+                 padding_mode='zeros',
+                 order=('conv', 'norm', 'act')):
+        super(ConvModule, self).__init__()
+        assert conv_cfg is None or isinstance(conv_cfg, dict)
+        assert norm_cfg is None or isinstance(norm_cfg, dict)
+        assert act_cfg is None or isinstance(act_cfg, dict)
+        official_padding_mode = ['zeros', 'circular']
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        self.inplace = inplace
+        self.with_spectral_norm = with_spectral_norm
+        self.with_explicit_padding = padding_mode not in official_padding_mode
+        self.order = order
+        assert isinstance(self.order, tuple) and len(self.order) == 3
+        assert set(order) == set(['conv', 'norm', 'act'])
+
+        self.with_norm = norm_cfg is not None
+        self.with_activation = act_cfg is not None
+        # if the conv layer is before a norm layer, bias is unnecessary.
+        if bias == 'auto':
+            bias = not self.with_norm
+        self.with_bias = bias
+
+        if self.with_explicit_padding:
+            pad_cfg = dict(type=padding_mode)
+            self.padding_layer = build_padding_layer(pad_cfg, padding)
+
+        # reset padding to 0 for conv module
+        conv_padding = 0 if self.with_explicit_padding else padding
+        # build convolution layer
+        self.conv = build_conv_layer(
+            conv_cfg,
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=conv_padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias)
+        # export the attributes of self.conv to a higher level for convenience
+        self.in_channels = self.conv.in_channels
+        self.out_channels = self.conv.out_channels
+        self.kernel_size = self.conv.kernel_size
+        self.stride = self.conv.stride
+        self.padding = padding
+        self.dilation = self.conv.dilation
+        self.transposed = self.conv.transposed
+        self.output_padding = self.conv.output_padding
+        self.groups = self.conv.groups
+
+        if self.with_spectral_norm:
+            self.conv = nn.utils.spectral_norm(self.conv)
+
+        # build normalization layers
+        if self.with_norm:
+            # norm layer is after conv layer
+            if order.index('norm') > order.index('conv'):
+                norm_channels = out_channels
+            else:
+                norm_channels = in_channels
+            self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
+            self.add_module(self.norm_name, norm)
+            if self.with_bias:
+                if isinstance(norm, (_BatchNorm, _InstanceNorm)):
+                    warnings.warn(
+                        'Unnecessary conv bias before batch/instance norm')
+        else:
+            self.norm_name = None
+
+        # build activation layer
+        if self.with_activation:
+            act_cfg_ = act_cfg.copy()
+            # nn.Tanh has no 'inplace' argument
+            if act_cfg_['type'] not in [
+                    'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish'
+            ]:
+                act_cfg_.setdefault('inplace', inplace)
+            self.activate = build_activation_layer(act_cfg_)
+
+        # Use msra init by default
+        self.init_weights()
+
+    @property
+    def norm(self):
+        if self.norm_name:
+            return getattr(self, self.norm_name)
+        else:
+            return None
+
+    def init_weights(self):
+        # 1. It is mainly for customized conv layers with their own
+        #    initialization manners by calling their own ``init_weights()``,
+        #    and we do not want ConvModule to override the initialization.
+        # 2. For customized conv layers without their own initialization
+        #    manners (that is, they don't have their own ``init_weights()``)
+        #    and PyTorch's conv layers, they will be initialized by
+        #    this method with default ``kaiming_init``.
+        # Note: For PyTorch's conv layers, they will be overwritten by our
+        #    initialization implementation using default ``kaiming_init``.
+        if not hasattr(self.conv, 'init_weights'):
+            if self.with_activation and self.act_cfg['type'] == 'LeakyReLU':
+                nonlinearity = 'leaky_relu'
+                a = self.act_cfg.get('negative_slope', 0.01)
+            else:
+                nonlinearity = 'relu'
+                a = 0
+            kaiming_init(self.conv, a=a, nonlinearity=nonlinearity)
+        if self.with_norm:
+            constant_init(self.norm, 1, bias=0)
+
+    def forward(self, x, activate=True, norm=True):
+        for layer in self.order:
+            if layer == 'conv':
+                if self.with_explicit_padding:
+                    x = self.padding_layer(x)
+                x = self.conv(x)
+            elif layer == 'norm' and norm and self.with_norm:
+                x = self.norm(x)
+            elif layer == 'act' and activate and self.with_activation:
+                x = self.activate(x)
+        return x
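+
+# Hypothetical usage sketch (not part of the upstream file): a conv + BN + ReLU
+# block built in one call; the conv bias is dropped automatically because a
+# norm layer follows the conv.
+if __name__ == '__main__':
+    import torch
+    block = ConvModule(3, 16, 3, padding=1, norm_cfg=dict(type='BN'))
+    print(block(torch.rand(2, 3, 32, 32)).shape)  # torch.Size([2, 16, 32, 32])
+    print(block.conv.bias is None)  # True, since bias='auto' and norm is set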
diff --git a/annotator/uniformer/mmcv/cnn/bricks/conv_ws.py b/annotator/uniformer/mmcv/cnn/bricks/conv_ws.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3941e27874993418b3b5708d5a7485f175ff9c8
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/conv_ws.py
@@ -0,0 +1,148 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .registry import CONV_LAYERS
+
+
+def conv_ws_2d(input,
+               weight,
+               bias=None,
+               stride=1,
+               padding=0,
+               dilation=1,
+               groups=1,
+               eps=1e-5):
+    c_in = weight.size(0)
+    weight_flat = weight.view(c_in, -1)
+    mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
+    std = weight_flat.std(dim=1, keepdim=True).view(c_in, 1, 1, 1)
+    weight = (weight - mean) / (std + eps)
+    return F.conv2d(input, weight, bias, stride, padding, dilation, groups)
+
+
+@CONV_LAYERS.register_module('ConvWS')
+class ConvWS2d(nn.Conv2d):
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 bias=True,
+                 eps=1e-5):
+        super(ConvWS2d, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias)
+        self.eps = eps
+
+    def forward(self, x):
+        return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding,
+                          self.dilation, self.groups, self.eps)
+
+
+@CONV_LAYERS.register_module(name='ConvAWS')
+class ConvAWS2d(nn.Conv2d):
+    """AWS (Adaptive Weight Standardization)
+
+    This is a variant of Weight Standardization
+    (https://arxiv.org/pdf/1903.10520.pdf)
+    It is used in DetectoRS to avoid NaN
+    (https://arxiv.org/pdf/2006.02334.pdf)
+
+    Args:
+        in_channels (int): Number of channels in the input image
+        out_channels (int): Number of channels produced by the convolution
+        kernel_size (int or tuple): Size of the conv kernel
+        stride (int or tuple, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): Zero-padding added to both sides of
+            the input. Default: 0
+        dilation (int or tuple, optional): Spacing between kernel elements.
+            Default: 1
+        groups (int, optional): Number of blocked connections from input
+            channels to output channels. Default: 1
+        bias (bool, optional): If set True, adds a learnable bias to the
+            output. Default: True
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 bias=True):
+        super().__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias)
+        self.register_buffer('weight_gamma',
+                             torch.ones(self.out_channels, 1, 1, 1))
+        self.register_buffer('weight_beta',
+                             torch.zeros(self.out_channels, 1, 1, 1))
+
+    def _get_weight(self, weight):
+        weight_flat = weight.view(weight.size(0), -1)
+        mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
+        std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
+        weight = (weight - mean) / std
+        weight = self.weight_gamma * weight + self.weight_beta
+        return weight
+
+    def forward(self, x):
+        weight = self._get_weight(self.weight)
+        return F.conv2d(x, weight, self.bias, self.stride, self.padding,
+                        self.dilation, self.groups)
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        """Override default load function.
+
+        AWS overrides the function _load_from_state_dict to recover
+        weight_gamma and weight_beta if they are missing. If weight_gamma and
+        weight_beta are found in the checkpoint, this function will return
+        after super()._load_from_state_dict. Otherwise, it will compute the
+        mean and std of the pretrained weights and store them in weight_beta
+        and weight_gamma.
+        """
+
+        self.weight_gamma.data.fill_(-1)
+        local_missing_keys = []
+        super()._load_from_state_dict(state_dict, prefix, local_metadata,
+                                      strict, local_missing_keys,
+                                      unexpected_keys, error_msgs)
+        if self.weight_gamma.data.mean() > 0:
+            for k in local_missing_keys:
+                missing_keys.append(k)
+            return
+        weight = self.weight.data
+        weight_flat = weight.view(weight.size(0), -1)
+        mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
+        std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
+        self.weight_beta.data.copy_(mean)
+        self.weight_gamma.data.copy_(std)
+        missing_gamma_beta = [
+            k for k in local_missing_keys
+            if k.endswith('weight_gamma') or k.endswith('weight_beta')
+        ]
+        for k in missing_gamma_beta:
+            local_missing_keys.remove(k)
+        for k in local_missing_keys:
+            missing_keys.append(k)
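+
+# Hypothetical usage sketch (not part of the upstream file): ConvWS2d is a
+# drop-in replacement for nn.Conv2d that standardizes its weights on the fly.
+if __name__ == '__main__':
+    conv = ConvWS2d(3, 8, kernel_size=3, padding=1)
+    print(conv(torch.rand(1, 3, 16, 16)).shape)  # torch.Size([1, 8, 16, 16])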
diff --git a/annotator/uniformer/mmcv/cnn/bricks/depthwise_separable_conv_module.py b/annotator/uniformer/mmcv/cnn/bricks/depthwise_separable_conv_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..722d5d8d71f75486e2db3008907c4eadfca41d63
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/depthwise_separable_conv_module.py
@@ -0,0 +1,96 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+
+from .conv_module import ConvModule
+
+
+class DepthwiseSeparableConvModule(nn.Module):
+    """Depthwise separable convolution module.
+
+    See https://arxiv.org/pdf/1704.04861.pdf for details.
+
+    This module replaces the single conv block of a ConvModule with two conv
+    blocks: a depthwise conv block and a pointwise conv block. The depthwise
+    conv block contains depthwise-conv/norm/activation layers. The pointwise
+    conv block contains pointwise-conv/norm/activation layers. Note that a
+    norm/activation layer is only present in the depthwise conv block if
+    `norm_cfg` and `act_cfg` are specified.
+
+    Args:
+        in_channels (int): Number of channels in the input feature map.
+            Same as that in ``nn._ConvNd``.
+        out_channels (int): Number of channels produced by the convolution.
+            Same as that in ``nn._ConvNd``.
+        kernel_size (int | tuple[int]): Size of the convolving kernel.
+            Same as that in ``nn._ConvNd``.
+        stride (int | tuple[int]): Stride of the convolution.
+            Same as that in ``nn._ConvNd``. Default: 1.
+        padding (int | tuple[int]): Zero-padding added to both sides of
+            the input. Same as that in ``nn._ConvNd``. Default: 0.
+        dilation (int | tuple[int]): Spacing between kernel elements.
+            Same as that in ``nn._ConvNd``. Default: 1.
+        norm_cfg (dict): Default norm config for both depthwise ConvModule and
+            pointwise ConvModule. Default: None.
+        act_cfg (dict): Default activation config for both depthwise ConvModule
+            and pointwise ConvModule. Default: dict(type='ReLU').
+        dw_norm_cfg (dict): Norm config of depthwise ConvModule. If it is
+            'default', it will be the same as `norm_cfg`. Default: 'default'.
+        dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is
+            'default', it will be the same as `act_cfg`. Default: 'default'.
+        pw_norm_cfg (dict): Norm config of pointwise ConvModule. If it is
+            'default', it will be the same as `norm_cfg`. Default: 'default'.
+        pw_act_cfg (dict): Activation config of pointwise ConvModule. If it is
+            'default', it will be the same as `act_cfg`. Default: 'default'.
+        kwargs (optional): Other shared arguments for depthwise and pointwise
+            ConvModule. See ConvModule for ref.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 norm_cfg=None,
+                 act_cfg=dict(type='ReLU'),
+                 dw_norm_cfg='default',
+                 dw_act_cfg='default',
+                 pw_norm_cfg='default',
+                 pw_act_cfg='default',
+                 **kwargs):
+        super(DepthwiseSeparableConvModule, self).__init__()
+        assert 'groups' not in kwargs, 'groups should not be specified'
+
+        # if norm/activation config of depthwise/pointwise ConvModule is not
+        # specified, use default config.
+        dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg
+        dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg
+        pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg
+        pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg
+
+        # depthwise convolution
+        self.depthwise_conv = ConvModule(
+            in_channels,
+            in_channels,
+            kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=in_channels,
+            norm_cfg=dw_norm_cfg,
+            act_cfg=dw_act_cfg,
+            **kwargs)
+
+        self.pointwise_conv = ConvModule(
+            in_channels,
+            out_channels,
+            1,
+            norm_cfg=pw_norm_cfg,
+            act_cfg=pw_act_cfg,
+            **kwargs)
+
+    def forward(self, x):
+        x = self.depthwise_conv(x)
+        x = self.pointwise_conv(x)
+        return x
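+
+# Hypothetical usage sketch (not part of the upstream file): the depthwise +
+# pointwise pair keeps the output shape of a regular 3x3 conv while using far
+# fewer parameters.
+if __name__ == '__main__':
+    import torch
+    module = DepthwiseSeparableConvModule(32, 64, 3, padding=1)
+    print(module(torch.rand(1, 32, 16, 16)).shape)  # torch.Size([1, 64, 16, 16])
+    print(sum(p.numel() for p in module.parameters()))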
diff --git a/annotator/uniformer/mmcv/cnn/bricks/drop.py b/annotator/uniformer/mmcv/cnn/bricks/drop.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7b4fccd457a0d51fb10c789df3c8537fe7b67c1
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/drop.py
@@ -0,0 +1,65 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+from annotator.uniformer.mmcv import build_from_cfg
+from .registry import DROPOUT_LAYERS
+
+
+def drop_path(x, drop_prob=0., training=False):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of
+    residual blocks).
+
+    We follow the implementation
+    https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py  # noqa: E501
+    """
+    if drop_prob == 0. or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    # handle tensors with different dimensions, not just 4D tensors.
+    shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
+    random_tensor = keep_prob + torch.rand(
+        shape, dtype=x.dtype, device=x.device)
+    output = x.div(keep_prob) * random_tensor.floor()
+    return output
+
+
+@DROPOUT_LAYERS.register_module()
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample  (when applied in main path of
+    residual blocks).
+
+    We follow the implementation
+    https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py  # noqa: E501
+
+    Args:
+        drop_prob (float): Probability of the path to be zeroed. Default: 0.1
+    """
+
+    def __init__(self, drop_prob=0.1):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+
+@DROPOUT_LAYERS.register_module()
+class Dropout(nn.Dropout):
+    """A wrapper for ``torch.nn.Dropout``, We rename the ``p`` of
+    ``torch.nn.Dropout`` to ``drop_prob`` so as to be consistent with
+    ``DropPath``
+
+    Args:
+        drop_prob (float): Probability of the elements to be
+            zeroed. Default: 0.5.
+        inplace (bool):  Do the operation inplace or not. Default: False.
+    """
+
+    def __init__(self, drop_prob=0.5, inplace=False):
+        super().__init__(p=drop_prob, inplace=inplace)
+
+
+def build_dropout(cfg, default_args=None):
+    """Builder for drop out layers."""
+    return build_from_cfg(cfg, DROPOUT_LAYERS, default_args)
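+
+# Hypothetical usage sketch (not part of the upstream file): DropPath only has
+# an effect in training mode; in eval mode it is the identity.
+if __name__ == '__main__':
+    layer = build_dropout(dict(type='DropPath', drop_prob=0.2))
+    x = torch.ones(4, 8)
+    layer.train()
+    print(layer(x).shape)  # torch.Size([4, 8])
+    layer.eval()
+    print(torch.equal(layer(x), x))  # True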
diff --git a/annotator/uniformer/mmcv/cnn/bricks/generalized_attention.py b/annotator/uniformer/mmcv/cnn/bricks/generalized_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..988d9adf2f289ef223bd1c680a5ae1d3387f0269
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/generalized_attention.py
@@ -0,0 +1,412 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..utils import kaiming_init
+from .registry import PLUGIN_LAYERS
+
+
+@PLUGIN_LAYERS.register_module()
+class GeneralizedAttention(nn.Module):
+    """GeneralizedAttention module.
+
+    See 'An Empirical Study of Spatial Attention Mechanisms in Deep Networks'
+    (https://arxiv.org/abs/1904.05873) for details.
+
+    Args:
+        in_channels (int): Channels of the input feature map.
+        spatial_range (int): The spatial range. -1 indicates no spatial range
+            constraint. Default: -1.
+        num_heads (int): The head number of empirical_attention module.
+            Default: 9.
+        position_embedding_dim (int): The position embedding dimension.
+            Default: -1.
+        position_magnitude (int): A multiplier acting on coord difference.
+            Default: 1.
+        kv_stride (int): The feature stride acting on key/value feature map.
+            Default: 2.
+        q_stride (int): The feature stride acting on query feature map.
+            Default: 1.
+        attention_type (str): A binary indicator string for indicating which
+            items in generalized empirical_attention module are used.
+            Default: '1111'.
+
+            - '1000' indicates 'query and key content' (appr - appr) item,
+            - '0100' indicates 'query content and relative position'
+              (appr - position) item,
+            - '0010' indicates 'key content only' (bias - appr) item,
+            - '0001' indicates 'relative position only' (bias - position) item.
+    """
+
+    _abbr_ = 'gen_attention_block'
+
+    def __init__(self,
+                 in_channels,
+                 spatial_range=-1,
+                 num_heads=9,
+                 position_embedding_dim=-1,
+                 position_magnitude=1,
+                 kv_stride=2,
+                 q_stride=1,
+                 attention_type='1111'):
+
+        super(GeneralizedAttention, self).__init__()
+
+        # hard range means local range for non-local operation
+        self.position_embedding_dim = (
+            position_embedding_dim
+            if position_embedding_dim > 0 else in_channels)
+
+        self.position_magnitude = position_magnitude
+        self.num_heads = num_heads
+        self.in_channels = in_channels
+        self.spatial_range = spatial_range
+        self.kv_stride = kv_stride
+        self.q_stride = q_stride
+        self.attention_type = [bool(int(_)) for _ in attention_type]
+        self.qk_embed_dim = in_channels // num_heads
+        out_c = self.qk_embed_dim * num_heads
+
+        if self.attention_type[0] or self.attention_type[1]:
+            self.query_conv = nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=out_c,
+                kernel_size=1,
+                bias=False)
+            self.query_conv.kaiming_init = True
+
+        if self.attention_type[0] or self.attention_type[2]:
+            self.key_conv = nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=out_c,
+                kernel_size=1,
+                bias=False)
+            self.key_conv.kaiming_init = True
+
+        self.v_dim = in_channels // num_heads
+        self.value_conv = nn.Conv2d(
+            in_channels=in_channels,
+            out_channels=self.v_dim * num_heads,
+            kernel_size=1,
+            bias=False)
+        self.value_conv.kaiming_init = True
+
+        if self.attention_type[1] or self.attention_type[3]:
+            self.appr_geom_fc_x = nn.Linear(
+                self.position_embedding_dim // 2, out_c, bias=False)
+            self.appr_geom_fc_x.kaiming_init = True
+
+            self.appr_geom_fc_y = nn.Linear(
+                self.position_embedding_dim // 2, out_c, bias=False)
+            self.appr_geom_fc_y.kaiming_init = True
+
+        if self.attention_type[2]:
+            stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
+            appr_bias_value = -2 * stdv * torch.rand(out_c) + stdv
+            self.appr_bias = nn.Parameter(appr_bias_value)
+
+        if self.attention_type[3]:
+            stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
+            geom_bias_value = -2 * stdv * torch.rand(out_c) + stdv
+            self.geom_bias = nn.Parameter(geom_bias_value)
+
+        self.proj_conv = nn.Conv2d(
+            in_channels=self.v_dim * num_heads,
+            out_channels=in_channels,
+            kernel_size=1,
+            bias=True)
+        self.proj_conv.kaiming_init = True
+        self.gamma = nn.Parameter(torch.zeros(1))
+
+        if self.spatial_range >= 0:
+            # only works when non local is after 3*3 conv
+            if in_channels == 256:
+                max_len = 84
+            elif in_channels == 512:
+                max_len = 42
+
+            max_len_kv = int((max_len - 1.0) / self.kv_stride + 1)
+            local_constraint_map = np.ones(
+                (max_len, max_len, max_len_kv, max_len_kv),
+                dtype=int)  # `np.int` was removed in NumPy >= 1.24
+            for iy in range(max_len):
+                for ix in range(max_len):
+                    local_constraint_map[
+                        iy, ix,
+                        max((iy - self.spatial_range) //
+                            self.kv_stride, 0):min((iy + self.spatial_range +
+                                                    1) // self.kv_stride +
+                                                   1, max_len),
+                        max((ix - self.spatial_range) //
+                            self.kv_stride, 0):min((ix + self.spatial_range +
+                                                    1) // self.kv_stride +
+                                                   1, max_len)] = 0
+
+            self.local_constraint_map = nn.Parameter(
+                torch.from_numpy(local_constraint_map).byte(),
+                requires_grad=False)
+
+        if self.q_stride > 1:
+            self.q_downsample = nn.AvgPool2d(
+                kernel_size=1, stride=self.q_stride)
+        else:
+            self.q_downsample = None
+
+        if self.kv_stride > 1:
+            self.kv_downsample = nn.AvgPool2d(
+                kernel_size=1, stride=self.kv_stride)
+        else:
+            self.kv_downsample = None
+
+        self.init_weights()
+
+    def get_position_embedding(self,
+                               h,
+                               w,
+                               h_kv,
+                               w_kv,
+                               q_stride,
+                               kv_stride,
+                               device,
+                               dtype,
+                               feat_dim,
+                               wave_length=1000):
+        # the default type of Tensor is float32, leading to type mismatch
+        # in fp16 mode. Cast it to support fp16 mode.
+        h_idxs = torch.linspace(0, h - 1, h).to(device=device, dtype=dtype)
+        h_idxs = h_idxs.view((h, 1)) * q_stride
+
+        w_idxs = torch.linspace(0, w - 1, w).to(device=device, dtype=dtype)
+        w_idxs = w_idxs.view((w, 1)) * q_stride
+
+        h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).to(
+            device=device, dtype=dtype)
+        h_kv_idxs = h_kv_idxs.view((h_kv, 1)) * kv_stride
+
+        w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).to(
+            device=device, dtype=dtype)
+        w_kv_idxs = w_kv_idxs.view((w_kv, 1)) * kv_stride
+
+        # (h, h_kv, 1)
+        h_diff = h_idxs.unsqueeze(1) - h_kv_idxs.unsqueeze(0)
+        h_diff *= self.position_magnitude
+
+        # (w, w_kv, 1)
+        w_diff = w_idxs.unsqueeze(1) - w_kv_idxs.unsqueeze(0)
+        w_diff *= self.position_magnitude
+
+        feat_range = torch.arange(0, feat_dim / 4).to(
+            device=device, dtype=dtype)
+
+        dim_mat = torch.Tensor([wave_length]).to(device=device, dtype=dtype)
+        dim_mat = dim_mat**((4. / feat_dim) * feat_range)
+        dim_mat = dim_mat.view((1, 1, -1))
+
+        embedding_x = torch.cat(
+            ((w_diff / dim_mat).sin(), (w_diff / dim_mat).cos()), dim=2)
+
+        embedding_y = torch.cat(
+            ((h_diff / dim_mat).sin(), (h_diff / dim_mat).cos()), dim=2)
+
+        return embedding_x, embedding_y
+
+    def forward(self, x_input):
+        num_heads = self.num_heads
+
+        # use empirical_attention
+        if self.q_downsample is not None:
+            x_q = self.q_downsample(x_input)
+        else:
+            x_q = x_input
+        n, _, h, w = x_q.shape
+
+        if self.kv_downsample is not None:
+            x_kv = self.kv_downsample(x_input)
+        else:
+            x_kv = x_input
+        _, _, h_kv, w_kv = x_kv.shape
+
+        if self.attention_type[0] or self.attention_type[1]:
+            proj_query = self.query_conv(x_q).view(
+                (n, num_heads, self.qk_embed_dim, h * w))
+            proj_query = proj_query.permute(0, 1, 3, 2)
+
+        if self.attention_type[0] or self.attention_type[2]:
+            proj_key = self.key_conv(x_kv).view(
+                (n, num_heads, self.qk_embed_dim, h_kv * w_kv))
+
+        if self.attention_type[1] or self.attention_type[3]:
+            position_embed_x, position_embed_y = self.get_position_embedding(
+                h, w, h_kv, w_kv, self.q_stride, self.kv_stride,
+                x_input.device, x_input.dtype, self.position_embedding_dim)
+            # (n, num_heads, w, w_kv, dim)
+            position_feat_x = self.appr_geom_fc_x(position_embed_x).\
+                view(1, w, w_kv, num_heads, self.qk_embed_dim).\
+                permute(0, 3, 1, 2, 4).\
+                repeat(n, 1, 1, 1, 1)
+
+            # (n, num_heads, h, h_kv, dim)
+            position_feat_y = self.appr_geom_fc_y(position_embed_y).\
+                view(1, h, h_kv, num_heads, self.qk_embed_dim).\
+                permute(0, 3, 1, 2, 4).\
+                repeat(n, 1, 1, 1, 1)
+
+            position_feat_x /= math.sqrt(2)
+            position_feat_y /= math.sqrt(2)
+
+        # accelerate for saliency only
+        if (np.sum(self.attention_type) == 1) and self.attention_type[2]:
+            appr_bias = self.appr_bias.\
+                view(1, num_heads, 1, self.qk_embed_dim).\
+                repeat(n, 1, 1, 1)
+
+            energy = torch.matmul(appr_bias, proj_key).\
+                view(n, num_heads, 1, h_kv * w_kv)
+
+            h = 1
+            w = 1
+        else:
+            # energy: (n, num_heads, h*w, h_kv*w_kv), query terms before key
+            if not self.attention_type[0]:
+                energy = torch.zeros(
+                    n,
+                    num_heads,
+                    h,
+                    w,
+                    h_kv,
+                    w_kv,
+                    dtype=x_input.dtype,
+                    device=x_input.device)
+
+            # attention_type[0]: appr - appr
+            # attention_type[1]: appr - position
+            # attention_type[2]: bias - appr
+            # attention_type[3]: bias - position
+            if self.attention_type[0] or self.attention_type[2]:
+                if self.attention_type[0] and self.attention_type[2]:
+                    appr_bias = self.appr_bias.\
+                        view(1, num_heads, 1, self.qk_embed_dim)
+                    energy = torch.matmul(proj_query + appr_bias, proj_key).\
+                        view(n, num_heads, h, w, h_kv, w_kv)
+
+                elif self.attention_type[0]:
+                    energy = torch.matmul(proj_query, proj_key).\
+                        view(n, num_heads, h, w, h_kv, w_kv)
+
+                elif self.attention_type[2]:
+                    appr_bias = self.appr_bias.\
+                        view(1, num_heads, 1, self.qk_embed_dim).\
+                        repeat(n, 1, 1, 1)
+
+                    energy += torch.matmul(appr_bias, proj_key).\
+                        view(n, num_heads, 1, 1, h_kv, w_kv)
+
+            if self.attention_type[1] or self.attention_type[3]:
+                if self.attention_type[1] and self.attention_type[3]:
+                    geom_bias = self.geom_bias.\
+                        view(1, num_heads, 1, self.qk_embed_dim)
+
+                    proj_query_reshape = (proj_query + geom_bias).\
+                        view(n, num_heads, h, w, self.qk_embed_dim)
+
+                    energy_x = torch.matmul(
+                        proj_query_reshape.permute(0, 1, 3, 2, 4),
+                        position_feat_x.permute(0, 1, 2, 4, 3))
+                    energy_x = energy_x.\
+                        permute(0, 1, 3, 2, 4).unsqueeze(4)
+
+                    energy_y = torch.matmul(
+                        proj_query_reshape,
+                        position_feat_y.permute(0, 1, 2, 4, 3))
+                    energy_y = energy_y.unsqueeze(5)
+
+                    energy += energy_x + energy_y
+
+                elif self.attention_type[1]:
+                    proj_query_reshape = proj_query.\
+                        view(n, num_heads, h, w, self.qk_embed_dim)
+                    proj_query_reshape = proj_query_reshape.\
+                        permute(0, 1, 3, 2, 4)
+                    position_feat_x_reshape = position_feat_x.\
+                        permute(0, 1, 2, 4, 3)
+                    position_feat_y_reshape = position_feat_y.\
+                        permute(0, 1, 2, 4, 3)
+
+                    energy_x = torch.matmul(proj_query_reshape,
+                                            position_feat_x_reshape)
+                    energy_x = energy_x.permute(0, 1, 3, 2, 4).unsqueeze(4)
+
+                    energy_y = torch.matmul(proj_query_reshape,
+                                            position_feat_y_reshape)
+                    energy_y = energy_y.unsqueeze(5)
+
+                    energy += energy_x + energy_y
+
+                elif self.attention_type[3]:
+                    geom_bias = self.geom_bias.\
+                        view(1, num_heads, self.qk_embed_dim, 1).\
+                        repeat(n, 1, 1, 1)
+
+                    position_feat_x_reshape = position_feat_x.\
+                        view(n, num_heads, w*w_kv, self.qk_embed_dim)
+
+                    position_feat_y_reshape = position_feat_y.\
+                        view(n, num_heads, h * h_kv, self.qk_embed_dim)
+
+                    energy_x = torch.matmul(position_feat_x_reshape, geom_bias)
+                    energy_x = energy_x.view(n, num_heads, 1, w, 1, w_kv)
+
+                    energy_y = torch.matmul(position_feat_y_reshape, geom_bias)
+                    energy_y = energy_y.view(n, num_heads, h, 1, h_kv, 1)
+
+                    energy += energy_x + energy_y
+
+            energy = energy.view(n, num_heads, h * w, h_kv * w_kv)
+
+        if self.spatial_range >= 0:
+            cur_local_constraint_map = \
+                self.local_constraint_map[:h, :w, :h_kv, :w_kv].\
+                contiguous().\
+                view(1, 1, h*w, h_kv*w_kv)
+
+            energy = energy.masked_fill_(cur_local_constraint_map,
+                                         float('-inf'))
+
+        attention = F.softmax(energy, 3)
+
+        proj_value = self.value_conv(x_kv)
+        proj_value_reshape = proj_value.\
+            view((n, num_heads, self.v_dim, h_kv * w_kv)).\
+            permute(0, 1, 3, 2)
+
+        out = torch.matmul(attention, proj_value_reshape).\
+            permute(0, 1, 3, 2).\
+            contiguous().\
+            view(n, self.v_dim * self.num_heads, h, w)
+
+        out = self.proj_conv(out)
+
+        # output is downsampled, upsample back to input size
+        if self.q_downsample is not None:
+            out = F.interpolate(
+                out,
+                size=x_input.shape[2:],
+                mode='bilinear',
+                align_corners=False)
+
+        out = self.gamma * out + x_input
+        return out
+
+    def init_weights(self):
+        for m in self.modules():
+            if hasattr(m, 'kaiming_init') and m.kaiming_init:
+                kaiming_init(
+                    m,
+                    mode='fan_in',
+                    nonlinearity='leaky_relu',
+                    bias=0,
+                    distribution='uniform',
+                    a=1)
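+
+
+# Editor's note: an illustrative usage sketch, not part of the upstream mmcv
+# sources. It instantiates the module directly (the registered name
+# 'GeneralizedAttention' could equally be built through PLUGIN_LAYERS) and
+# checks that the residual output keeps the input shape.
+if __name__ == '__main__':
+    self_attn = GeneralizedAttention(
+        in_channels=256, num_heads=8, kv_stride=2, attention_type='1111')
+    feat = torch.rand(2, 256, 20, 20)
+    out = self_attn(feat)
+    assert out.shape == feat.shape  # gamma starts at 0, so out == feat here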
diff --git a/annotator/uniformer/mmcv/cnn/bricks/hsigmoid.py b/annotator/uniformer/mmcv/cnn/bricks/hsigmoid.py
new file mode 100644
index 0000000000000000000000000000000000000000..30b1a3d6580cf0360710426fbea1f05acdf07b4b
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/hsigmoid.py
@@ -0,0 +1,34 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+
+from .registry import ACTIVATION_LAYERS
+
+
+@ACTIVATION_LAYERS.register_module()
+class HSigmoid(nn.Module):
+    """Hard Sigmoid Module. Apply the hard sigmoid function:
+    Hsigmoid(x) = min(max((x + bias) / divisor, min_value), max_value)
+    Default: Hsigmoid(x) = min(max((x + 1) / 2, 0), 1)
+
+    Args:
+        bias (float): Bias of the input feature map. Default: 1.0.
+        divisor (float): Divisor of the input feature map. Default: 2.0.
+        min_value (float): Lower bound value. Default: 0.0.
+        max_value (float): Upper bound value. Default: 1.0.
+
+    Returns:
+        Tensor: The output tensor.
+    """
+
+    def __init__(self, bias=1.0, divisor=2.0, min_value=0.0, max_value=1.0):
+        super(HSigmoid, self).__init__()
+        self.bias = bias
+        self.divisor = divisor
+        assert self.divisor != 0
+        self.min_value = min_value
+        self.max_value = max_value
+
+    def forward(self, x):
+        x = (x + self.bias) / self.divisor
+
+        return x.clamp_(self.min_value, self.max_value)
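+
+
+# Editor's note: a small sanity-check sketch, not part of the upstream mmcv
+# sources. With the defaults bias=1.0 and divisor=2.0, HSigmoid maps
+# -1 -> 0, 0 -> 0.5 and 1 -> 1, a piecewise-linear approximation of sigmoid.
+if __name__ == '__main__':
+    import torch
+    act = HSigmoid()
+    print(act(torch.tensor([-2., -1., 0., 1., 2.])))
+    # tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000])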
diff --git a/annotator/uniformer/mmcv/cnn/bricks/hswish.py b/annotator/uniformer/mmcv/cnn/bricks/hswish.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e0c090ff037c99ee6c5c84c4592e87beae02208
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/hswish.py
@@ -0,0 +1,29 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+
+from .registry import ACTIVATION_LAYERS
+
+
+@ACTIVATION_LAYERS.register_module()
+class HSwish(nn.Module):
+    """Hard Swish Module.
+
+    This module applies the hard swish function:
+
+    .. math::
+        Hswish(x) = x * ReLU6(x + 3) / 6
+
+    Args:
+        inplace (bool): can optionally do the operation in-place.
+            Default: False.
+
+    Returns:
+        Tensor: The output tensor.
+    """
+
+    def __init__(self, inplace=False):
+        super(HSwish, self).__init__()
+        self.act = nn.ReLU6(inplace)
+
+    def forward(self, x):
+        return x * self.act(x + 3) / 6
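+
+
+# Editor's note: a small sanity-check sketch, not part of the upstream mmcv
+# sources. Since Hswish(x) = x * ReLU6(x + 3) / 6, the output is 0 for
+# x <= -3 and equals x for x >= 3.
+if __name__ == '__main__':
+    import torch
+    act = HSwish()
+    print(act(torch.tensor([-4., -3., 0., 3., 6.])))  # ~[0, 0, 0, 3, 6]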
diff --git a/annotator/uniformer/mmcv/cnn/bricks/non_local.py b/annotator/uniformer/mmcv/cnn/bricks/non_local.py
new file mode 100644
index 0000000000000000000000000000000000000000..92d00155ef275c1201ea66bba30470a1785cc5d7
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/non_local.py
@@ -0,0 +1,306 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta
+
+import torch
+import torch.nn as nn
+
+from ..utils import constant_init, normal_init
+from .conv_module import ConvModule
+from .registry import PLUGIN_LAYERS
+
+
+class _NonLocalNd(nn.Module, metaclass=ABCMeta):
+    """Basic Non-local module.
+
+    This module is proposed in
+    "Non-local Neural Networks"
+    Paper reference: https://arxiv.org/abs/1711.07971
+    Code reference: https://github.com/AlexHex7/Non-local_pytorch
+
+    Args:
+        in_channels (int): Channels of the input feature map.
+        reduction (int): Channel reduction ratio. Default: 2.
+        use_scale (bool): Whether to scale pairwise_weight by
+            `1/sqrt(inter_channels)` when the mode is `embedded_gaussian`.
+            Default: True.
+        conv_cfg (None | dict): The config dict for convolution layers.
+            If not specified, it will use `nn.Conv2d` for convolution layers.
+            Default: None.
+        norm_cfg (None | dict): The config dict for normalization layers.
+            Default: None. (This parameter is only applicable to conv_out.)
+        mode (str): Options are `gaussian`, `concatenation`,
+            `embedded_gaussian` and `dot_product`. Default: embedded_gaussian.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 reduction=2,
+                 use_scale=True,
+                 conv_cfg=None,
+                 norm_cfg=None,
+                 mode='embedded_gaussian',
+                 **kwargs):
+        super(_NonLocalNd, self).__init__()
+        self.in_channels = in_channels
+        self.reduction = reduction
+        self.use_scale = use_scale
+        self.inter_channels = max(in_channels // reduction, 1)
+        self.mode = mode
+
+        if mode not in [
+                'gaussian', 'embedded_gaussian', 'dot_product', 'concatenation'
+        ]:
+            raise ValueError("Mode should be in 'gaussian', 'concatenation', "
+                             f"'embedded_gaussian' or 'dot_product', but got "
+                             f'{mode} instead.')
+
+        # g, theta, phi are defaulted as `nn.ConvNd`.
+        # Here we use ConvModule for potential usage.
+        self.g = ConvModule(
+            self.in_channels,
+            self.inter_channels,
+            kernel_size=1,
+            conv_cfg=conv_cfg,
+            act_cfg=None)
+        self.conv_out = ConvModule(
+            self.inter_channels,
+            self.in_channels,
+            kernel_size=1,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=None)
+
+        if self.mode != 'gaussian':
+            self.theta = ConvModule(
+                self.in_channels,
+                self.inter_channels,
+                kernel_size=1,
+                conv_cfg=conv_cfg,
+                act_cfg=None)
+            self.phi = ConvModule(
+                self.in_channels,
+                self.inter_channels,
+                kernel_size=1,
+                conv_cfg=conv_cfg,
+                act_cfg=None)
+
+        if self.mode == 'concatenation':
+            self.concat_project = ConvModule(
+                self.inter_channels * 2,
+                1,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                bias=False,
+                act_cfg=dict(type='ReLU'))
+
+        self.init_weights(**kwargs)
+
+    def init_weights(self, std=0.01, zeros_init=True):
+        if self.mode != 'gaussian':
+            for m in [self.g, self.theta, self.phi]:
+                normal_init(m.conv, std=std)
+        else:
+            normal_init(self.g.conv, std=std)
+        if zeros_init:
+            if self.conv_out.norm_cfg is None:
+                constant_init(self.conv_out.conv, 0)
+            else:
+                constant_init(self.conv_out.norm, 0)
+        else:
+            if self.conv_out.norm_cfg is None:
+                normal_init(self.conv_out.conv, std=std)
+            else:
+                normal_init(self.conv_out.norm, std=std)
+
+    def gaussian(self, theta_x, phi_x):
+        # NonLocal1d pairwise_weight: [N, H, H]
+        # NonLocal2d pairwise_weight: [N, HxW, HxW]
+        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+        pairwise_weight = torch.matmul(theta_x, phi_x)
+        pairwise_weight = pairwise_weight.softmax(dim=-1)
+        return pairwise_weight
+
+    def embedded_gaussian(self, theta_x, phi_x):
+        # NonLocal1d pairwise_weight: [N, H, H]
+        # NonLocal2d pairwise_weight: [N, HxW, HxW]
+        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+        pairwise_weight = torch.matmul(theta_x, phi_x)
+        if self.use_scale:
+            # theta_x.shape[-1] is `self.inter_channels`
+            pairwise_weight /= theta_x.shape[-1]**0.5
+        pairwise_weight = pairwise_weight.softmax(dim=-1)
+        return pairwise_weight
+
+    def dot_product(self, theta_x, phi_x):
+        # NonLocal1d pairwise_weight: [N, H, H]
+        # NonLocal2d pairwise_weight: [N, HxW, HxW]
+        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+        pairwise_weight = torch.matmul(theta_x, phi_x)
+        pairwise_weight /= pairwise_weight.shape[-1]
+        return pairwise_weight
+
+    def concatenation(self, theta_x, phi_x):
+        # NonLocal1d pairwise_weight: [N, H, H]
+        # NonLocal2d pairwise_weight: [N, HxW, HxW]
+        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+        h = theta_x.size(2)
+        w = phi_x.size(3)
+        theta_x = theta_x.repeat(1, 1, 1, w)
+        phi_x = phi_x.repeat(1, 1, h, 1)
+
+        concat_feature = torch.cat([theta_x, phi_x], dim=1)
+        pairwise_weight = self.concat_project(concat_feature)
+        n, _, h, w = pairwise_weight.size()
+        pairwise_weight = pairwise_weight.view(n, h, w)
+        pairwise_weight /= pairwise_weight.shape[-1]
+
+        return pairwise_weight
+
+    def forward(self, x):
+        # Assume `reduction = 1`, then `inter_channels = C`.
+        # In `gaussian` mode theta/phi come from the raw input, so their
+        # channel dim is `in_channels` regardless of `reduction`.
+
+        # NonLocal1d x: [N, C, H]
+        # NonLocal2d x: [N, C, H, W]
+        # NonLocal3d x: [N, C, T, H, W]
+        n = x.size(0)
+
+        # NonLocal1d g_x: [N, H, C]
+        # NonLocal2d g_x: [N, HxW, C]
+        # NonLocal3d g_x: [N, TxHxW, C]
+        g_x = self.g(x).view(n, self.inter_channels, -1)
+        g_x = g_x.permute(0, 2, 1)
+
+        # NonLocal1d theta_x: [N, H, C], phi_x: [N, C, H]
+        # NonLocal2d theta_x: [N, HxW, C], phi_x: [N, C, HxW]
+        # NonLocal3d theta_x: [N, TxHxW, C], phi_x: [N, C, TxHxW]
+        if self.mode == 'gaussian':
+            theta_x = x.view(n, self.in_channels, -1)
+            theta_x = theta_x.permute(0, 2, 1)
+            if self.sub_sample:
+                phi_x = self.phi(x).view(n, self.in_channels, -1)
+            else:
+                phi_x = x.view(n, self.in_channels, -1)
+        elif self.mode == 'concatenation':
+            theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
+            phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
+        else:
+            theta_x = self.theta(x).view(n, self.inter_channels, -1)
+            theta_x = theta_x.permute(0, 2, 1)
+            phi_x = self.phi(x).view(n, self.inter_channels, -1)
+
+        pairwise_func = getattr(self, self.mode)
+        # NonLocal1d pairwise_weight: [N, H, H]
+        # NonLocal2d pairwise_weight: [N, HxW, HxW]
+        # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
+        pairwise_weight = pairwise_func(theta_x, phi_x)
+
+        # NonLocal1d y: [N, H, C]
+        # NonLocal2d y: [N, HxW, C]
+        # NonLocal3d y: [N, TxHxW, C]
+        y = torch.matmul(pairwise_weight, g_x)
+        # NonLocal1d y: [N, C, H]
+        # NonLocal2d y: [N, C, H, W]
+        # NonLocal3d y: [N, C, T, H, W]
+        y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
+                                                    *x.size()[2:])
+
+        output = x + self.conv_out(y)
+
+        return output
+
+
+class NonLocal1d(_NonLocalNd):
+    """1D Non-local module.
+
+    Args:
+        in_channels (int): Same as `NonLocalND`.
+        sub_sample (bool): Whether to apply max pooling after pairwise
+            function (Note that the `sub_sample` is applied on spatial only).
+            Default: False.
+        conv_cfg (None | dict): Same as `NonLocalND`.
+            Default: dict(type='Conv1d').
+    """
+
+    def __init__(self,
+                 in_channels,
+                 sub_sample=False,
+                 conv_cfg=dict(type='Conv1d'),
+                 **kwargs):
+        super(NonLocal1d, self).__init__(
+            in_channels, conv_cfg=conv_cfg, **kwargs)
+
+        self.sub_sample = sub_sample
+
+        if sub_sample:
+            max_pool_layer = nn.MaxPool1d(kernel_size=2)
+            self.g = nn.Sequential(self.g, max_pool_layer)
+            if self.mode != 'gaussian':
+                self.phi = nn.Sequential(self.phi, max_pool_layer)
+            else:
+                self.phi = max_pool_layer
+
+
+@PLUGIN_LAYERS.register_module()
+class NonLocal2d(_NonLocalNd):
+    """2D Non-local module.
+
+    Args:
+        in_channels (int): Same as `NonLocalND`.
+        sub_sample (bool): Whether to apply max pooling after pairwise
+            function (Note that the `sub_sample` is applied on spatial only).
+            Default: False.
+        conv_cfg (None | dict): Same as `NonLocalND`.
+            Default: dict(type='Conv2d').
+    """
+
+    _abbr_ = 'nonlocal_block'
+
+    def __init__(self,
+                 in_channels,
+                 sub_sample=False,
+                 conv_cfg=dict(type='Conv2d'),
+                 **kwargs):
+        super(NonLocal2d, self).__init__(
+            in_channels, conv_cfg=conv_cfg, **kwargs)
+
+        self.sub_sample = sub_sample
+
+        if sub_sample:
+            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
+            self.g = nn.Sequential(self.g, max_pool_layer)
+            if self.mode != 'gaussian':
+                self.phi = nn.Sequential(self.phi, max_pool_layer)
+            else:
+                self.phi = max_pool_layer
+
+
+class NonLocal3d(_NonLocalNd):
+    """3D Non-local module.
+
+    Args:
+        in_channels (int): Same as `NonLocalND`.
+        sub_sample (bool): Whether to apply max pooling after pairwise
+            function (Note that the `sub_sample` is applied on spatial only).
+            Default: False.
+        conv_cfg (None | dict): Same as `NonLocalND`.
+            Default: dict(type='Conv3d').
+    """
+
+    def __init__(self,
+                 in_channels,
+                 sub_sample=False,
+                 conv_cfg=dict(type='Conv3d'),
+                 **kwargs):
+        super(NonLocal3d, self).__init__(
+            in_channels, conv_cfg=conv_cfg, **kwargs)
+        self.sub_sample = sub_sample
+
+        if sub_sample:
+            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
+            self.g = nn.Sequential(self.g, max_pool_layer)
+            if self.mode != 'gaussian':
+                self.phi = nn.Sequential(self.phi, max_pool_layer)
+            else:
+                self.phi = max_pool_layer
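+
+
+# Editor's note: a minimal usage sketch, not part of the upstream mmcv
+# sources. NonLocal2d is the variant registered as a plugin layer; with
+# `sub_sample=True` the key/value branches are max-pooled, but the residual
+# output always keeps the input shape.
+if __name__ == '__main__':
+    block = NonLocal2d(in_channels=64, reduction=2,
+                       mode='embedded_gaussian', sub_sample=True)
+    x = torch.rand(2, 64, 32, 32)
+    assert block(x).shape == x.shape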
diff --git a/annotator/uniformer/mmcv/cnn/bricks/norm.py b/annotator/uniformer/mmcv/cnn/bricks/norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..408f4b42731b19a3beeef68b6a5e610d0bbc18b3
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/norm.py
@@ -0,0 +1,144 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+
+import torch.nn as nn
+
+from annotator.uniformer.mmcv.utils import is_tuple_of
+from annotator.uniformer.mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm, _InstanceNorm
+from .registry import NORM_LAYERS
+
+NORM_LAYERS.register_module('BN', module=nn.BatchNorm2d)
+NORM_LAYERS.register_module('BN1d', module=nn.BatchNorm1d)
+NORM_LAYERS.register_module('BN2d', module=nn.BatchNorm2d)
+NORM_LAYERS.register_module('BN3d', module=nn.BatchNorm3d)
+NORM_LAYERS.register_module('SyncBN', module=SyncBatchNorm)
+NORM_LAYERS.register_module('GN', module=nn.GroupNorm)
+NORM_LAYERS.register_module('LN', module=nn.LayerNorm)
+NORM_LAYERS.register_module('IN', module=nn.InstanceNorm2d)
+NORM_LAYERS.register_module('IN1d', module=nn.InstanceNorm1d)
+NORM_LAYERS.register_module('IN2d', module=nn.InstanceNorm2d)
+NORM_LAYERS.register_module('IN3d', module=nn.InstanceNorm3d)
+
+
+def infer_abbr(class_type):
+    """Infer abbreviation from the class name.
+
+    When we build a norm layer with `build_norm_layer()`, we want to preserve
+    the norm type in variable names, e.g, self.bn1, self.gn. This method will
+    infer the abbreviation to map class types to abbreviations.
+
+    Rule 1: If the class has the property "_abbr_", return the property.
+    Rule 2: If the parent class is _BatchNorm, GroupNorm, LayerNorm or
+    InstanceNorm, the abbreviation of this layer will be "bn", "gn", "ln" and
+    "in" respectively.
+    Rule 3: If the class name contains "batch", "group", "layer" or "instance",
+    the abbreviation of this layer will be "bn", "gn", "ln" and "in"
+    respectively.
+    Rule 4: Otherwise, the abbreviation falls back to "norm_layer".
+
+    Args:
+        class_type (type): The norm layer type.
+
+    Returns:
+        str: The inferred abbreviation.
+    """
+    if not inspect.isclass(class_type):
+        raise TypeError(
+            f'class_type must be a type, but got {type(class_type)}')
+    if hasattr(class_type, '_abbr_'):
+        return class_type._abbr_
+    if issubclass(class_type, _InstanceNorm):  # IN is a subclass of BN
+        return 'in'
+    elif issubclass(class_type, _BatchNorm):
+        return 'bn'
+    elif issubclass(class_type, nn.GroupNorm):
+        return 'gn'
+    elif issubclass(class_type, nn.LayerNorm):
+        return 'ln'
+    else:
+        class_name = class_type.__name__.lower()
+        if 'batch' in class_name:
+            return 'bn'
+        elif 'group' in class_name:
+            return 'gn'
+        elif 'layer' in class_name:
+            return 'ln'
+        elif 'instance' in class_name:
+            return 'in'
+        else:
+            return 'norm_layer'
+
+
+def build_norm_layer(cfg, num_features, postfix=''):
+    """Build normalization layer.
+
+    Args:
+        cfg (dict): The norm layer config, which should contain:
+
+            - type (str): Layer type.
+            - layer args: Args needed to instantiate a norm layer.
+            - requires_grad (bool, optional): Whether the layer parameters
+              require gradient updates. Default: True.
+        num_features (int): Number of input channels.
+        postfix (int | str): The postfix to be appended into norm abbreviation
+            to create named layer.
+
+    Returns:
+        (str, nn.Module): The first element is the layer name consisting of
+            abbreviation and postfix, e.g., bn1, gn. The second element is the
+            created norm layer.
+    """
+    if not isinstance(cfg, dict):
+        raise TypeError('cfg must be a dict')
+    if 'type' not in cfg:
+        raise KeyError('the cfg dict must contain the key "type"')
+    cfg_ = cfg.copy()
+
+    layer_type = cfg_.pop('type')
+    if layer_type not in NORM_LAYERS:
+        raise KeyError(f'Unrecognized norm type {layer_type}')
+
+    norm_layer = NORM_LAYERS.get(layer_type)
+    abbr = infer_abbr(norm_layer)
+
+    assert isinstance(postfix, (int, str))
+    name = abbr + str(postfix)
+
+    requires_grad = cfg_.pop('requires_grad', True)
+    cfg_.setdefault('eps', 1e-5)
+    if layer_type != 'GN':
+        layer = norm_layer(num_features, **cfg_)
+        if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
+            layer._specify_ddp_gpu_num(1)
+    else:
+        assert 'num_groups' in cfg_
+        layer = norm_layer(num_channels=num_features, **cfg_)
+
+    for param in layer.parameters():
+        param.requires_grad = requires_grad
+
+    return name, layer
+
+
+def is_norm(layer, exclude=None):
+    """Check if a layer is a normalization layer.
+
+    Args:
+        layer (nn.Module): The layer to be checked.
+        exclude (type | tuple[type]): Types to be excluded.
+
+    Returns:
+        bool: Whether the layer is a norm layer.
+    """
+    if exclude is not None:
+        if not isinstance(exclude, tuple):
+            exclude = (exclude, )
+        if not is_tuple_of(exclude, type):
+            raise TypeError(
+                f'"exclude" must be either None or type or a tuple of types, '
+                f'but got {type(exclude)}: {exclude}')
+
+    if exclude and isinstance(layer, exclude):
+        return False
+
+    all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
+    return isinstance(layer, all_norm_bases)
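+
+
+# Editor's note: an illustrative sketch, not part of the upstream mmcv
+# sources. The returned name combines the inferred abbreviation with the
+# postfix, which is how backbones end up with attributes such as `bn1`.
+if __name__ == '__main__':
+    name, layer = build_norm_layer(dict(type='BN', requires_grad=True),
+                                   num_features=64, postfix=1)
+    print(name, layer)  # bn1 BatchNorm2d(64, eps=1e-05, ...)
+    name, layer = build_norm_layer(dict(type='GN', num_groups=8), 64)
+    print(name, layer)  # gn GroupNorm(8, 64, eps=1e-05, affine=True)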
diff --git a/annotator/uniformer/mmcv/cnn/bricks/padding.py b/annotator/uniformer/mmcv/cnn/bricks/padding.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4ac6b28a1789bd551c613a7d3e7b622433ac7ec
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/padding.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+
+from .registry import PADDING_LAYERS
+
+PADDING_LAYERS.register_module('zero', module=nn.ZeroPad2d)
+PADDING_LAYERS.register_module('reflect', module=nn.ReflectionPad2d)
+PADDING_LAYERS.register_module('replicate', module=nn.ReplicationPad2d)
+
+
+def build_padding_layer(cfg, *args, **kwargs):
+    """Build padding layer.
+
+    Args:
+        cfg (None or dict): The padding layer config, which should contain:
+            - type (str): Layer type.
+            - layer args: Args needed to instantiate a padding layer.
+
+    Returns:
+        nn.Module: Created padding layer.
+    """
+    if not isinstance(cfg, dict):
+        raise TypeError('cfg must be a dict')
+    if 'type' not in cfg:
+        raise KeyError('the cfg dict must contain the key "type"')
+
+    cfg_ = cfg.copy()
+    padding_type = cfg_.pop('type')
+    if padding_type not in PADDING_LAYERS:
+        raise KeyError(f'Unrecognized padding type {padding_type}.')
+    else:
+        padding_layer = PADDING_LAYERS.get(padding_type)
+
+    layer = padding_layer(*args, **kwargs, **cfg_)
+
+    return layer
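+
+
+# Editor's note: a minimal usage sketch, not part of the upstream mmcv
+# sources. Positional arguments after the config dict are forwarded to the
+# underlying padding module; here they give the padding size.
+if __name__ == '__main__':
+    import torch
+    pad = build_padding_layer(dict(type='reflect'), 2)
+    print(pad(torch.rand(1, 3, 8, 8)).shape)  # torch.Size([1, 3, 12, 12])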
diff --git a/annotator/uniformer/mmcv/cnn/bricks/plugin.py b/annotator/uniformer/mmcv/cnn/bricks/plugin.py
new file mode 100644
index 0000000000000000000000000000000000000000..07c010d4053174dd41107aa654ea67e82b46a25c
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/plugin.py
@@ -0,0 +1,88 @@
+import inspect
+import platform
+
+from .registry import PLUGIN_LAYERS
+
+if platform.system() == 'Windows':
+    import regex as re
+else:
+    import re
+
+
+def infer_abbr(class_type):
+    """Infer abbreviation from the class name.
+
+    This method will infer the abbreviation to map class types to
+    abbreviations.
+
+    Rule 1: If the class has the property "abbr", return the property.
+    Rule 2: Otherwise, the abbreviation falls back to snake case of class
+    name, e.g. the abbreviation of ``FancyBlock`` will be ``fancy_block``.
+
+    Args:
+        class_type (type): The norm layer type.
+
+    Returns:
+        str: The inferred abbreviation.
+    """
+
+    def camel2snack(word):
+        """Convert camel case word into snack case.
+
+        Modified from `inflection lib
+        <https://inflection.readthedocs.io/en/latest/#inflection.underscore>`_.
+
+        Example::
+
+            >>> camel2snack("FancyBlock")
+            'fancy_block'
+        """
+
+        word = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', word)
+        word = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', word)
+        word = word.replace('-', '_')
+        return word.lower()
+
+    if not inspect.isclass(class_type):
+        raise TypeError(
+            f'class_type must be a type, but got {type(class_type)}')
+    if hasattr(class_type, '_abbr_'):
+        return class_type._abbr_
+    else:
+        return camel2snack(class_type.__name__)
+
+
+def build_plugin_layer(cfg, postfix='', **kwargs):
+    """Build plugin layer.
+
+    Args:
+        cfg (None or dict): cfg should contain:
+            type (str): identify plugin layer type.
+            layer args: args needed to instantiate a plugin layer.
+        postfix (int, str): appended into norm abbreviation to
+            create named layer. Default: ''.
+
+    Returns:
+        tuple[str, nn.Module]:
+            name (str): abbreviation + postfix
+            layer (nn.Module): created plugin layer
+    """
+    if not isinstance(cfg, dict):
+        raise TypeError('cfg must be a dict')
+    if 'type' not in cfg:
+        raise KeyError('the cfg dict must contain the key "type"')
+    cfg_ = cfg.copy()
+
+    layer_type = cfg_.pop('type')
+    if layer_type not in PLUGIN_LAYERS:
+        raise KeyError(f'Unrecognized plugin type {layer_type}')
+
+    plugin_layer = PLUGIN_LAYERS.get(layer_type)
+    abbr = infer_abbr(plugin_layer)
+
+    assert isinstance(postfix, (int, str))
+    name = abbr + str(postfix)
+
+    layer = plugin_layer(**kwargs, **cfg_)
+
+    return name, layer
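+
+
+# Editor's note: an illustrative sketch, not part of the upstream mmcv
+# sources. `FancyBlock` is a hypothetical class used only to show the
+# snake-case fallback of `infer_abbr`; real plugins in this package (e.g.
+# NonLocal2d or GeneralizedAttention) define `_abbr_` themselves and are
+# built from a config dict via `build_plugin_layer`.
+if __name__ == '__main__':
+    class FancyBlock:
+        pass
+
+    print(infer_abbr(FancyBlock))  # fancy_block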
diff --git a/annotator/uniformer/mmcv/cnn/bricks/registry.py b/annotator/uniformer/mmcv/cnn/bricks/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..39eabc58db4b5954478a2ac1ab91cea5e45ab055
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/registry.py
@@ -0,0 +1,16 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from annotator.uniformer.mmcv.utils import Registry
+
+CONV_LAYERS = Registry('conv layer')
+NORM_LAYERS = Registry('norm layer')
+ACTIVATION_LAYERS = Registry('activation layer')
+PADDING_LAYERS = Registry('padding layer')
+UPSAMPLE_LAYERS = Registry('upsample layer')
+PLUGIN_LAYERS = Registry('plugin layer')
+
+DROPOUT_LAYERS = Registry('drop out layers')
+POSITIONAL_ENCODING = Registry('position encoding')
+ATTENTION = Registry('attention')
+FEEDFORWARD_NETWORK = Registry('feed-forward Network')
+TRANSFORMER_LAYER = Registry('transformerLayer')
+TRANSFORMER_LAYER_SEQUENCE = Registry('transformer-layers sequence')
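+
+
+# Editor's note: an illustrative sketch, not part of the upstream mmcv
+# sources. `ScaledTanh` is a hypothetical module used only to show how a
+# custom layer is registered and then looked up by its class name; builders
+# such as `build_activation_layer` construct registered classes from config
+# dicts in exactly this way.
+if __name__ == '__main__':
+    import torch.nn as nn
+
+    @ACTIVATION_LAYERS.register_module()
+    class ScaledTanh(nn.Module):
+        def __init__(self, scale=2.0):
+            super().__init__()
+            self.scale = scale
+
+        def forward(self, x):
+            return self.scale * x.tanh()
+
+    act = ACTIVATION_LAYERS.get('ScaledTanh')(scale=3.0)
+    print(act)  # ScaledTanh()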
diff --git a/annotator/uniformer/mmcv/cnn/bricks/scale.py b/annotator/uniformer/mmcv/cnn/bricks/scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..c905fffcc8bf998d18d94f927591963c428025e2
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/scale.py
@@ -0,0 +1,21 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+
+class Scale(nn.Module):
+    """A learnable scale parameter.
+
+    This layer scales the input by a learnable factor. It multiplies a
+    learnable scale parameter of shape (1,) with input of any shape.
+
+    Args:
+        scale (float): Initial value of scale factor. Default: 1.0
+    """
+
+    def __init__(self, scale=1.0):
+        super(Scale, self).__init__()
+        self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))
+
+    def forward(self, x):
+        return x * self.scale
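+
+
+# Editor's note: a minimal sketch, not part of the upstream mmcv sources.
+# The factor is an ordinary nn.Parameter, so it is trained together with the
+# rest of the model (e.g. per-level scaling of regression outputs in
+# detection heads).
+if __name__ == '__main__':
+    s = Scale(scale=1.0)
+    x = torch.rand(2, 4)
+    assert torch.allclose(s(x), x)  # identity at the initial value 1.0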
diff --git a/annotator/uniformer/mmcv/cnn/bricks/swish.py b/annotator/uniformer/mmcv/cnn/bricks/swish.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2ca8ed7b749413f011ae54aac0cab27e6f0b51f
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/swish.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+from .registry import ACTIVATION_LAYERS
+
+
+@ACTIVATION_LAYERS.register_module()
+class Swish(nn.Module):
+    """Swish Module.
+
+    This module applies the swish function:
+
+    .. math::
+        Swish(x) = x * Sigmoid(x)
+
+    Returns:
+        Tensor: The output tensor.
+    """
+
+    def __init__(self):
+        super(Swish, self).__init__()
+
+    def forward(self, x):
+        return x * torch.sigmoid(x)
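+
+
+# Editor's note: a small sanity-check sketch, not part of the upstream mmcv
+# sources. Swish(x) = x * sigmoid(x), so Swish(0) = 0 and Swish(x) -> x for
+# large positive x.
+if __name__ == '__main__':
+    act = Swish()
+    print(act(torch.tensor([-1., 0., 1.])))  # ~[-0.2689, 0.0000, 0.7311]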
diff --git a/annotator/uniformer/mmcv/cnn/bricks/transformer.py b/annotator/uniformer/mmcv/cnn/bricks/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e61ae0dd941a7be00b3e41a3de833ec50470a45f
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/transformer.py
@@ -0,0 +1,595 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import warnings
+
+import torch
+import torch.nn as nn
+
+from annotator.uniformer.mmcv import ConfigDict, deprecated_api_warning
+from annotator.uniformer.mmcv.cnn import Linear, build_activation_layer, build_norm_layer
+from annotator.uniformer.mmcv.runner.base_module import BaseModule, ModuleList, Sequential
+from annotator.uniformer.mmcv.utils import build_from_cfg
+from .drop import build_dropout
+from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
+                       TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)
+
+# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
+try:
+    from annotator.uniformer.mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention  # noqa F401
+    warnings.warn(
+        ImportWarning(
+            '``MultiScaleDeformableAttention`` has been moved to '
+            '``mmcv.ops.multi_scale_deform_attn``, please change original path '  # noqa E501
+            '``from annotator.uniformer.mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` '  # noqa E501
+            'to ``from annotator.uniformer.mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` '  # noqa E501
+        ))
+
+except ImportError:
+    warnings.warn('Failed to import ``MultiScaleDeformableAttention`` from '
+                  '``mmcv.ops.multi_scale_deform_attn``; '
+                  'you should install ``mmcv-full`` if you need this module.')
+
+
+def build_positional_encoding(cfg, default_args=None):
+    """Builder for Position Encoding."""
+    return build_from_cfg(cfg, POSITIONAL_ENCODING, default_args)
+
+
+def build_attention(cfg, default_args=None):
+    """Builder for attention."""
+    return build_from_cfg(cfg, ATTENTION, default_args)
+
+
+def build_feedforward_network(cfg, default_args=None):
+    """Builder for feed-forward network (FFN)."""
+    return build_from_cfg(cfg, FEEDFORWARD_NETWORK, default_args)
+
+
+def build_transformer_layer(cfg, default_args=None):
+    """Builder for transformer layer."""
+    return build_from_cfg(cfg, TRANSFORMER_LAYER, default_args)
+
+
+def build_transformer_layer_sequence(cfg, default_args=None):
+    """Builder for transformer encoder and transformer decoder."""
+    return build_from_cfg(cfg, TRANSFORMER_LAYER_SEQUENCE, default_args)
+
+
+@ATTENTION.register_module()
+class MultiheadAttention(BaseModule):
+    """A wrapper for ``torch.nn.MultiheadAttention``.
+
+    This module implements MultiheadAttention with identity connection,
+    and positional encoding  is also passed as input.
+
+    Args:
+        embed_dims (int): The embedding dimension.
+        num_heads (int): Parallel attention heads.
+        attn_drop (float): A Dropout layer on attn_output_weights.
+            Default: 0.0.
+        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
+            Default: 0.0.
+        dropout_layer (obj:`ConfigDict`): The dropout_layer used
+            when adding the shortcut.
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+        batch_first (bool): When True, Key, Query and Value have shape
+            (batch, n, embed_dim); otherwise (n, batch, embed_dim).
+            Default: False.
+    """
+
+    def __init__(self,
+                 embed_dims,
+                 num_heads,
+                 attn_drop=0.,
+                 proj_drop=0.,
+                 dropout_layer=dict(type='Dropout', drop_prob=0.),
+                 init_cfg=None,
+                 batch_first=False,
+                 **kwargs):
+        super(MultiheadAttention, self).__init__(init_cfg)
+        if 'dropout' in kwargs:
+            warnings.warn('The argument `dropout` in MultiheadAttention '
+                          'has been deprecated; now you can separately '
+                          'set `attn_drop` (float), `proj_drop` (float), '
+                          'and `dropout_layer` (dict).')
+            attn_drop = kwargs['dropout']
+            dropout_layer['drop_prob'] = kwargs.pop('dropout')
+
+        self.embed_dims = embed_dims
+        self.num_heads = num_heads
+        self.batch_first = batch_first
+
+        self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
+                                          **kwargs)
+
+        self.proj_drop = nn.Dropout(proj_drop)
+        self.dropout_layer = build_dropout(
+            dropout_layer) if dropout_layer else nn.Identity()
+
+    @deprecated_api_warning({'residual': 'identity'},
+                            cls_name='MultiheadAttention')
+    def forward(self,
+                query,
+                key=None,
+                value=None,
+                identity=None,
+                query_pos=None,
+                key_pos=None,
+                attn_mask=None,
+                key_padding_mask=None,
+                **kwargs):
+        """Forward function for `MultiheadAttention`.
+
+        **kwargs allow passing a more general data flow when combining
+        with other operations in `transformerlayer`.
+
+        Args:
+            query (Tensor): The input query with shape [num_queries, bs,
+                embed_dims] if self.batch_first is False, else
+                [bs, num_queries, embed_dims].
+            key (Tensor): The key tensor with shape [num_keys, bs,
+                embed_dims] if self.batch_first is False, else
+                [bs, num_keys, embed_dims].
+                If None, the ``query`` will be used. Defaults to None.
+            value (Tensor): The value tensor with same shape as `key`.
+                Same in `nn.MultiheadAttention.forward`. Defaults to None.
+                If None, the `key` will be used.
+            identity (Tensor): This tensor, with the same shape as `query`,
+                will be used for the identity link.
+                If None, `query` will be used. Defaults to None.
+            query_pos (Tensor): The positional encoding for `query`, with
+                the same shape as `query`. If not None, it will be added to
+                `query` before the forward function. Defaults to None.
+            key_pos (Tensor): The positional encoding for `key`, with the
+                same shape as `key`. Defaults to None. If not None, it will
+                be added to `key` before forward function. If None, and
+                `query_pos` has the same shape as `key`, then `query_pos`
+                will be used for `key_pos`. Defaults to None.
+            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
+                num_keys]. Same in `nn.MultiheadAttention.forward`.
+                Defaults to None.
+            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
+                Defaults to None.
+
+        Returns:
+            Tensor: forwarded results with shape
+                [num_queries, bs, embed_dims]
+                if self.batch_first is False, else
+                [bs, num_queries, embed_dims].
+        """
+
+        if key is None:
+            key = query
+        if value is None:
+            value = key
+        if identity is None:
+            identity = query
+        if key_pos is None:
+            if query_pos is not None:
+                # use query_pos if key_pos is not available
+                if query_pos.shape == key.shape:
+                    key_pos = query_pos
+                else:
+                    warnings.warn(f'position encoding of key is '
+                                  f'missing in {self.__class__.__name__}.')
+        if query_pos is not None:
+            query = query + query_pos
+        if key_pos is not None:
+            key = key + key_pos
+
+        # Because the dataflow('key', 'query', 'value') of
+        # ``torch.nn.MultiheadAttention`` is (num_query, batch,
+        # embed_dims), We should adjust the shape of dataflow from
+        # batch_first (batch, num_query, embed_dims) to num_query_first
+        # (num_query ,batch, embed_dims), and recover ``attn_output``
+        # from num_query_first to batch_first.
+        if self.batch_first:
+            query = query.transpose(0, 1)
+            key = key.transpose(0, 1)
+            value = value.transpose(0, 1)
+
+        out = self.attn(
+            query=query,
+            key=key,
+            value=value,
+            attn_mask=attn_mask,
+            key_padding_mask=key_padding_mask)[0]
+
+        if self.batch_first:
+            out = out.transpose(0, 1)
+
+        return identity + self.dropout_layer(self.proj_drop(out))
+
+
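+def _example_multihead_attention():
+    """Editor's illustrative sketch (hypothetical helper, not upstream mmcv).
+
+    Builds the wrapper from a config dict via ``build_attention`` and runs a
+    dummy sequence through it; with ``key``/``value`` left as None the call
+    is plain self-attention and the identity connection keeps the shape.
+    """
+    attn = build_attention(
+        dict(type='MultiheadAttention', embed_dims=256, num_heads=8))
+    query = torch.rand(10, 2, 256)  # (num_queries, batch, embed_dims)
+    out = attn(query)
+    assert out.shape == query.shape
+    return out
+
+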
+@FEEDFORWARD_NETWORK.register_module()
+class FFN(BaseModule):
+    """Implements feed-forward networks (FFNs) with identity connection.
+
+    Args:
+        embed_dims (int): The feature dimension. Same as
+            `MultiheadAttention`. Defaults: 256.
+        feedforward_channels (int): The hidden dimension of FFNs.
+            Defaults: 1024.
+        num_fcs (int, optional): The number of fully-connected layers in
+            FFNs. Default: 2.
+        act_cfg (dict, optional): The activation config for FFNs.
+            Default: dict(type='ReLU')
+        ffn_drop (float, optional): Probability of an element to be
+            zeroed in FFN. Default 0.0.
+        add_identity (bool, optional): Whether to add the
+            identity connection. Default: `True`.
+        dropout_layer (obj:`ConfigDict`): The dropout_layer used
+            when adding the shortcut.
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+
+    @deprecated_api_warning(
+        {
+            'dropout': 'ffn_drop',
+            'add_residual': 'add_identity'
+        },
+        cls_name='FFN')
+    def __init__(self,
+                 embed_dims=256,
+                 feedforward_channels=1024,
+                 num_fcs=2,
+                 act_cfg=dict(type='ReLU', inplace=True),
+                 ffn_drop=0.,
+                 dropout_layer=None,
+                 add_identity=True,
+                 init_cfg=None,
+                 **kwargs):
+        super(FFN, self).__init__(init_cfg)
+        assert num_fcs >= 2, 'num_fcs should be no less ' \
+            f'than 2. got {num_fcs}.'
+        self.embed_dims = embed_dims
+        self.feedforward_channels = feedforward_channels
+        self.num_fcs = num_fcs
+        self.act_cfg = act_cfg
+        self.activate = build_activation_layer(act_cfg)
+
+        layers = []
+        in_channels = embed_dims
+        for _ in range(num_fcs - 1):
+            layers.append(
+                Sequential(
+                    Linear(in_channels, feedforward_channels), self.activate,
+                    nn.Dropout(ffn_drop)))
+            in_channels = feedforward_channels
+        layers.append(Linear(feedforward_channels, embed_dims))
+        layers.append(nn.Dropout(ffn_drop))
+        self.layers = Sequential(*layers)
+        self.dropout_layer = build_dropout(
+            dropout_layer) if dropout_layer else torch.nn.Identity()
+        self.add_identity = add_identity
+
+    @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
+    def forward(self, x, identity=None):
+        """Forward function for `FFN`.
+
+        The function adds ``x`` to the output tensor if ``identity`` is None.
+        """
+        out = self.layers(x)
+        if not self.add_identity:
+            return self.dropout_layer(out)
+        if identity is None:
+            identity = x
+        return identity + self.dropout_layer(out)
+
+
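+def _example_ffn():
+    """Editor's illustrative sketch (hypothetical helper, not upstream mmcv).
+
+    A two-fc FFN with an identity connection: when ``identity`` is omitted
+    the input itself is added back, so the output shape equals the input
+    shape.
+    """
+    ffn = FFN(embed_dims=256, feedforward_channels=1024, ffn_drop=0.1)
+    x = torch.rand(10, 2, 256)
+    assert ffn(x).shape == x.shape
+    return ffn
+
+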
+@TRANSFORMER_LAYER.register_module()
+class BaseTransformerLayer(BaseModule):
+    """Base `TransformerLayer` for vision transformer.
+
+    It can be built from `mmcv.ConfigDict` and support more flexible
+    customization, for example, using any number of `FFN or LN ` and
+    use different kinds of `attention` by specifying a list of `ConfigDict`
+    named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
+    when you specify `norm` as the first element of `operation_order`.
+    More details about the `prenorm`: `On Layer Normalization in the
+    Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ .
+
+    Args:
+        attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
+            Configs for `self_attention` or `cross_attention` modules.
+            The order of the configs in the list should be consistent with
+            the corresponding attentions in operation_order.
+            If it is a dict, all of the attention modules in operation_order
+            will be built with this config. Default: None.
+        ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
+            Configs for FFN. The order of the configs in the list should be
+            consistent with the corresponding ffn in operation_order.
+            If it is a dict, all of the FFN modules in operation_order
+            will be built with this config.
+        operation_order (tuple[str]): The execution order of operation
+            in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
+            `prenorm` is supported when you specify the first element
+            as `norm`. Default: None.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='LN').
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+        batch_first (bool): When True, Key, Query and Value have shape
+            (batch, n, embed_dim); otherwise (n, batch, embed_dim).
+            Default: False.
+    """
+
+    def __init__(self,
+                 attn_cfgs=None,
+                 ffn_cfgs=dict(
+                     type='FFN',
+                     embed_dims=256,
+                     feedforward_channels=1024,
+                     num_fcs=2,
+                     ffn_drop=0.,
+                     act_cfg=dict(type='ReLU', inplace=True),
+                 ),
+                 operation_order=None,
+                 norm_cfg=dict(type='LN'),
+                 init_cfg=None,
+                 batch_first=False,
+                 **kwargs):
+
+        deprecated_args = dict(
+            feedforward_channels='feedforward_channels',
+            ffn_dropout='ffn_drop',
+            ffn_num_fcs='num_fcs')
+        for ori_name, new_name in deprecated_args.items():
+            if ori_name in kwargs:
+                warnings.warn(
+                    f'The argument `{ori_name}` in BaseTransformerLayer '
+                    f'has been deprecated; now you should set `{new_name}` '
+                    f'and other FFN-related arguments '
+                    f'in a dict named `ffn_cfgs`. ')
+                ffn_cfgs[new_name] = kwargs[ori_name]
+
+        super(BaseTransformerLayer, self).__init__(init_cfg)
+
+        self.batch_first = batch_first
+
+        assert set(operation_order) & set(
+            ['self_attn', 'norm', 'ffn', 'cross_attn']) == \
+            set(operation_order), f'The operation_order of' \
+            f' {self.__class__.__name__} should only ' \
+            f'contain operations from ' \
+            f"{['self_attn', 'norm', 'ffn', 'cross_attn']}"
+
+        num_attn = operation_order.count('self_attn') + operation_order.count(
+            'cross_attn')
+        if isinstance(attn_cfgs, dict):
+            attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
+        else:
+            assert num_attn == len(attn_cfgs), f'The length ' \
+                f'of attn_cfgs {len(attn_cfgs)} is ' \
+                f'not consistent with the number of attentions ' \
+                f'{num_attn} in operation_order {operation_order}.'
+
+        self.num_attn = num_attn
+        self.operation_order = operation_order
+        self.norm_cfg = norm_cfg
+        self.pre_norm = operation_order[0] == 'norm'
+        self.attentions = ModuleList()
+
+        index = 0
+        for operation_name in operation_order:
+            if operation_name in ['self_attn', 'cross_attn']:
+                if 'batch_first' in attn_cfgs[index]:
+                    assert self.batch_first == attn_cfgs[index]['batch_first']
+                else:
+                    attn_cfgs[index]['batch_first'] = self.batch_first
+                attention = build_attention(attn_cfgs[index])
+                # Some custom attentions used as `self_attn`
+                # or `cross_attn` can have different behavior.
+                attention.operation_name = operation_name
+                self.attentions.append(attention)
+                index += 1
+
+        self.embed_dims = self.attentions[0].embed_dims
+
+        self.ffns = ModuleList()
+        num_ffns = operation_order.count('ffn')
+        if isinstance(ffn_cfgs, dict):
+            ffn_cfgs = ConfigDict(ffn_cfgs)
+        if isinstance(ffn_cfgs, dict):
+            # ConfigDict is still a dict, so a single config is expanded
+            # into one copy per FFN in operation_order.
+            ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
+        assert len(ffn_cfgs) == num_ffns
+        for ffn_index in range(num_ffns):
+            if 'embed_dims' not in ffn_cfgs[ffn_index]:
+                ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
+            else:
+                assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
+            self.ffns.append(
+                build_feedforward_network(ffn_cfgs[ffn_index],
+                                          dict(type='FFN')))
+
+        self.norms = ModuleList()
+        num_norms = operation_order.count('norm')
+        for _ in range(num_norms):
+            self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])
+
+    def forward(self,
+                query,
+                key=None,
+                value=None,
+                query_pos=None,
+                key_pos=None,
+                attn_masks=None,
+                query_key_padding_mask=None,
+                key_padding_mask=None,
+                **kwargs):
+        """Forward function for `TransformerDecoderLayer`.
+
+        **kwargs contains some specific arguments of attentions.
+
+        Args:
+            query (Tensor): The input query with shape
+                [num_queries, bs, embed_dims] if
+                self.batch_first is False, else
+                [bs, num_queries, embed_dims].
+            key (Tensor): The key tensor with shape [num_keys, bs,
+                embed_dims] if self.batch_first is False, else
+                [bs, num_keys, embed_dims].
+            value (Tensor): The value tensor with same shape as `key`.
+            query_pos (Tensor): The positional encoding for `query`.
+                Default: None.
+            key_pos (Tensor): The positional encoding for `key`.
+                Default: None.
+            attn_masks (List[Tensor] | None): 2D Tensor used in
+                calculation of corresponding attention. The length of
+                it should equal to the number of `attention` in
+                `operation_order`. Default: None.
+            query_key_padding_mask (Tensor): ByteTensor for `query`, with
+                shape [bs, num_queries]. Only used in `self_attn` layer.
+                Defaults to None.
+            key_padding_mask (Tensor): ByteTensor for `key`, with
+                shape [bs, num_keys]. Default: None.
+
+        Returns:
+            Tensor: forwarded results with shape [num_queries, bs, embed_dims].
+        """
+
+        norm_index = 0
+        attn_index = 0
+        ffn_index = 0
+        identity = query
+        if attn_masks is None:
+            attn_masks = [None for _ in range(self.num_attn)]
+        elif isinstance(attn_masks, torch.Tensor):
+            attn_masks = [
+                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
+            ]
+            warnings.warn(f'Using the same attn_mask for all attentions in '
+                          f'{self.__class__.__name__}.')
+        else:
+            assert len(attn_masks) == self.num_attn, f'The length of ' \
+                        f'attn_masks {len(attn_masks)} must be equal ' \
+                        f'to the number of attentions in ' \
+                        f'operation_order ({self.num_attn}).'
+
+        for layer in self.operation_order:
+            if layer == 'self_attn':
+                temp_key = temp_value = query
+                query = self.attentions[attn_index](
+                    query,
+                    temp_key,
+                    temp_value,
+                    identity if self.pre_norm else None,
+                    query_pos=query_pos,
+                    key_pos=query_pos,
+                    attn_mask=attn_masks[attn_index],
+                    key_padding_mask=query_key_padding_mask,
+                    **kwargs)
+                attn_index += 1
+                identity = query
+
+            elif layer == 'norm':
+                query = self.norms[norm_index](query)
+                norm_index += 1
+
+            elif layer == 'cross_attn':
+                query = self.attentions[attn_index](
+                    query,
+                    key,
+                    value,
+                    identity if self.pre_norm else None,
+                    query_pos=query_pos,
+                    key_pos=key_pos,
+                    attn_mask=attn_masks[attn_index],
+                    key_padding_mask=key_padding_mask,
+                    **kwargs)
+                attn_index += 1
+                identity = query
+
+            elif layer == 'ffn':
+                query = self.ffns[ffn_index](
+                    query, identity if self.pre_norm else None)
+                ffn_index += 1
+
+        return query
+
+
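+# ---------------------------------------------------------------------------
+# Editor's sketch (not part of upstream mmcv): a minimal usage example for
+# BaseTransformerLayer. It assumes the 'MultiheadAttention' module is
+# registered in the ATTENTION registry, as in stock mmcv; the helper name
+# below is introduced purely for illustration.
+def _example_base_transformer_layer():
+    layer = BaseTransformerLayer(
+        attn_cfgs=dict(
+            type='MultiheadAttention', embed_dims=256, num_heads=8),
+        ffn_cfgs=dict(
+            type='FFN', embed_dims=256, feedforward_channels=1024),
+        operation_order=('self_attn', 'norm', 'ffn', 'norm'))
+    query = torch.rand(10, 2, 256)  # (num_queries, bs, embed_dims)
+    return layer(query)  # same shape as `query`
+
+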
+@TRANSFORMER_LAYER_SEQUENCE.register_module()
+class TransformerLayerSequence(BaseModule):
+    """Base class for TransformerEncoder and TransformerDecoder in vision
+    transformer.
+
+    As base-class of Encoder and Decoder in vision transformer.
+    Support customization such as specifying different kind
+    of `transformer_layer` in `transformer_coder`.
+
+    Args:
+        transformerlayers (list[obj:`mmcv.ConfigDict`] |
+            obj:`mmcv.ConfigDict`): Config of the transformer layers
+            in TransformerCoder. If it is an obj:`mmcv.ConfigDict`,
+            it will be repeated `num_layers` times to form a
+            list[`mmcv.ConfigDict`]. Default: None.
+        num_layers (int): The number of `TransformerLayer`. Default: None.
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+
+    def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
+        super(TransformerLayerSequence, self).__init__(init_cfg)
+        if isinstance(transformerlayers, dict):
+            transformerlayers = [
+                copy.deepcopy(transformerlayers) for _ in range(num_layers)
+            ]
+        else:
+            assert isinstance(transformerlayers, list) and \
+                   len(transformerlayers) == num_layers
+        self.num_layers = num_layers
+        self.layers = ModuleList()
+        for i in range(num_layers):
+            self.layers.append(build_transformer_layer(transformerlayers[i]))
+        self.embed_dims = self.layers[0].embed_dims
+        self.pre_norm = self.layers[0].pre_norm
+
+    def forward(self,
+                query,
+                key,
+                value,
+                query_pos=None,
+                key_pos=None,
+                attn_masks=None,
+                query_key_padding_mask=None,
+                key_padding_mask=None,
+                **kwargs):
+        """Forward function for `TransformerCoder`.
+
+        Args:
+            query (Tensor): Input query with shape
+                `(num_queries, bs, embed_dims)`.
+            key (Tensor): The key tensor with shape
+                `(num_keys, bs, embed_dims)`.
+            value (Tensor): The value tensor with shape
+                `(num_keys, bs, embed_dims)`.
+            query_pos (Tensor): The positional encoding for `query`.
+                Default: None.
+            key_pos (Tensor): The positional encoding for `key`.
+                Default: None.
+            attn_masks (List[Tensor], optional): Each element is 2D Tensor
+                which is used in calculation of corresponding attention in
+                operation_order. Default: None.
+            query_key_padding_mask (Tensor): ByteTensor for `query`, with
+                shape [bs, num_queries]. Only used in self-attention.
+                Default: None.
+            key_padding_mask (Tensor): ByteTensor for `key`, with
+                shape [bs, num_keys]. Default: None.
+
+        Returns:
+            Tensor: Results with shape [num_queries, bs, embed_dims].
+        """
+        for layer in self.layers:
+            query = layer(
+                query,
+                key,
+                value,
+                query_pos=query_pos,
+                key_pos=key_pos,
+                attn_masks=attn_masks,
+                query_key_padding_mask=query_key_padding_mask,
+                key_padding_mask=key_padding_mask,
+                **kwargs)
+        return query
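+
+
+# ---------------------------------------------------------------------------
+# Editor's sketch (not part of upstream mmcv): a small encoder built by
+# repeating one layer config. It assumes 'BaseTransformerLayer' and
+# 'MultiheadAttention' are registered in their respective registries, as in
+# stock mmcv; the helper name below is introduced purely for illustration.
+def _example_transformer_layer_sequence():
+    encoder = TransformerLayerSequence(
+        transformerlayers=dict(
+            type='BaseTransformerLayer',
+            attn_cfgs=dict(
+                type='MultiheadAttention', embed_dims=256, num_heads=8),
+            ffn_cfgs=dict(
+                type='FFN', embed_dims=256, feedforward_channels=1024),
+            operation_order=('self_attn', 'norm', 'ffn', 'norm')),
+        num_layers=2)
+    query = torch.rand(10, 2, 256)  # (num_queries, bs, embed_dims)
+    # key/value are unused by pure self-attention layers, so pass None.
+    return encoder(query, key=None, value=None)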
diff --git a/annotator/uniformer/mmcv/cnn/bricks/upsample.py b/annotator/uniformer/mmcv/cnn/bricks/upsample.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1a353767d0ce8518f0d7289bed10dba0178ed12
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/upsample.py
@@ -0,0 +1,84 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..utils import xavier_init
+from .registry import UPSAMPLE_LAYERS
+
+UPSAMPLE_LAYERS.register_module('nearest', module=nn.Upsample)
+UPSAMPLE_LAYERS.register_module('bilinear', module=nn.Upsample)
+
+
+@UPSAMPLE_LAYERS.register_module(name='pixel_shuffle')
+class PixelShufflePack(nn.Module):
+    """Pixel Shuffle upsample layer.
+
+    This module packs `F.pixel_shuffle()` and a nn.Conv2d module together to
+    achieve a simple upsampling with pixel shuffle.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        scale_factor (int): Upsample ratio.
+        upsample_kernel (int): Kernel size of the conv layer to expand the
+            channels.
+    """
+
+    def __init__(self, in_channels, out_channels, scale_factor,
+                 upsample_kernel):
+        super(PixelShufflePack, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.scale_factor = scale_factor
+        self.upsample_kernel = upsample_kernel
+        self.upsample_conv = nn.Conv2d(
+            self.in_channels,
+            self.out_channels * scale_factor * scale_factor,
+            self.upsample_kernel,
+            padding=(self.upsample_kernel - 1) // 2)
+        self.init_weights()
+
+    def init_weights(self):
+        xavier_init(self.upsample_conv, distribution='uniform')
+
+    def forward(self, x):
+        x = self.upsample_conv(x)
+        x = F.pixel_shuffle(x, self.scale_factor)
+        return x
+
+
+def build_upsample_layer(cfg, *args, **kwargs):
+    """Build upsample layer.
+
+    Args:
+        cfg (dict): The upsample layer config, which should contain:
+
+            - type (str): Layer type.
+            - scale_factor (int): Upsample ratio, which is not applicable to
+                deconv.
+            - layer args: Args needed to instantiate an upsample layer.
+        args (argument list): Arguments passed to the ``__init__``
+            method of the corresponding upsample layer.
+        kwargs (keyword arguments): Keyword arguments passed to the
+            ``__init__`` method of the corresponding upsample layer.
+
+    Returns:
+        nn.Module: Created upsample layer.
+    """
+    if not isinstance(cfg, dict):
+        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
+    if 'type' not in cfg:
+        raise KeyError(
+            f'the cfg dict must contain the key "type", but got {cfg}')
+    cfg_ = cfg.copy()
+
+    layer_type = cfg_.pop('type')
+    if layer_type not in UPSAMPLE_LAYERS:
+        raise KeyError(f'Unrecognized upsample type {layer_type}')
+    else:
+        upsample = UPSAMPLE_LAYERS.get(layer_type)
+
+    if upsample is nn.Upsample:
+        cfg_['mode'] = layer_type
+    layer = upsample(*args, **kwargs, **cfg_)
+    return layer
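+
+
+# ---------------------------------------------------------------------------
+# Editor's sketch (not part of upstream mmcv): building the two kinds of
+# upsample layers defined above. Channel numbers and scale factors are
+# arbitrary illustration values.
+def _example_build_upsample_layer():
+    # nn.Upsample in bilinear mode; `type` doubles as the interpolation mode.
+    bilinear_up = build_upsample_layer(
+        dict(type='bilinear', scale_factor=2, align_corners=False))
+    # PixelShufflePack: the conv expands channels, then F.pixel_shuffle
+    # rearranges them into a spatially upsampled map.
+    ps_up = build_upsample_layer(
+        dict(type='pixel_shuffle', in_channels=16, out_channels=16,
+             scale_factor=2, upsample_kernel=3))
+    return bilinear_up, ps_up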
diff --git a/annotator/uniformer/mmcv/cnn/bricks/wrappers.py b/annotator/uniformer/mmcv/cnn/bricks/wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..8aebf67bf52355a513f21756ee74fe510902d075
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/bricks/wrappers.py
@@ -0,0 +1,180 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+r"""Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/wrappers.py  # noqa: E501
+
+Wrap some nn modules to support empty tensor input. Currently, these wrappers
+are mainly used in mask heads like fcn_mask_head and maskiou_heads since mask
+heads are trained on only positive RoIs.
+"""
+import math
+
+import torch
+import torch.nn as nn
+from torch.nn.modules.utils import _pair, _triple
+
+from .registry import CONV_LAYERS, UPSAMPLE_LAYERS
+
+if torch.__version__ == 'parrots':
+    TORCH_VERSION = torch.__version__
+else:
+    # torch.__version__ could be 1.3.1+cu92, we only need the first two
+    # for comparison
+    TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
+
+
+def obsolete_torch_version(torch_version, version_threshold):
+    return torch_version == 'parrots' or torch_version <= version_threshold
+
+
+class NewEmptyTensorOp(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x, new_shape):
+        ctx.shape = x.shape
+        return x.new_empty(new_shape)
+
+    @staticmethod
+    def backward(ctx, grad):
+        shape = ctx.shape
+        return NewEmptyTensorOp.apply(grad, shape), None
+
+
+@CONV_LAYERS.register_module('Conv', force=True)
+class Conv2d(nn.Conv2d):
+
+    def forward(self, x):
+        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+            out_shape = [x.shape[0], self.out_channels]
+            for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
+                                     self.padding, self.stride, self.dilation):
+                o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
+                out_shape.append(o)
+            empty = NewEmptyTensorOp.apply(x, out_shape)
+            if self.training:
+                # produce dummy gradient to avoid DDP warning.
+                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+                return empty + dummy
+            else:
+                return empty
+
+        return super().forward(x)
+
+
+@CONV_LAYERS.register_module('Conv3d', force=True)
+class Conv3d(nn.Conv3d):
+
+    def forward(self, x):
+        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+            out_shape = [x.shape[0], self.out_channels]
+            for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
+                                     self.padding, self.stride, self.dilation):
+                o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
+                out_shape.append(o)
+            empty = NewEmptyTensorOp.apply(x, out_shape)
+            if self.training:
+                # produce dummy gradient to avoid DDP warning.
+                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+                return empty + dummy
+            else:
+                return empty
+
+        return super().forward(x)
+
+
+@CONV_LAYERS.register_module()
+@CONV_LAYERS.register_module('deconv')
+@UPSAMPLE_LAYERS.register_module('deconv', force=True)
+class ConvTranspose2d(nn.ConvTranspose2d):
+
+    def forward(self, x):
+        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+            out_shape = [x.shape[0], self.out_channels]
+            for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
+                                         self.padding, self.stride,
+                                         self.dilation, self.output_padding):
+                out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
+            empty = NewEmptyTensorOp.apply(x, out_shape)
+            if self.training:
+                # produce dummy gradient to avoid DDP warning.
+                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+                return empty + dummy
+            else:
+                return empty
+
+        return super().forward(x)
+
+
+@CONV_LAYERS.register_module()
+@CONV_LAYERS.register_module('deconv3d')
+@UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
+class ConvTranspose3d(nn.ConvTranspose3d):
+
+    def forward(self, x):
+        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+            out_shape = [x.shape[0], self.out_channels]
+            for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
+                                         self.padding, self.stride,
+                                         self.dilation, self.output_padding):
+                out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
+            empty = NewEmptyTensorOp.apply(x, out_shape)
+            if self.training:
+                # produce dummy gradient to avoid DDP warning.
+                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+                return empty + dummy
+            else:
+                return empty
+
+        return super().forward(x)
+
+
+class MaxPool2d(nn.MaxPool2d):
+
+    def forward(self, x):
+        # PyTorch 1.9 does not support empty tensor inference yet
+        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
+            out_shape = list(x.shape[:2])
+            for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size),
+                                     _pair(self.padding), _pair(self.stride),
+                                     _pair(self.dilation)):
+                o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
+                o = math.ceil(o) if self.ceil_mode else math.floor(o)
+                out_shape.append(o)
+            empty = NewEmptyTensorOp.apply(x, out_shape)
+            return empty
+
+        return super().forward(x)
+
+
+class MaxPool3d(nn.MaxPool3d):
+
+    def forward(self, x):
+        # PyTorch 1.9 does not support empty tensor inference yet
+        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
+            out_shape = list(x.shape[:2])
+            for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size),
+                                     _triple(self.padding),
+                                     _triple(self.stride),
+                                     _triple(self.dilation)):
+                o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
+                o = math.ceil(o) if self.ceil_mode else math.floor(o)
+                out_shape.append(o)
+            empty = NewEmptyTensorOp.apply(x, out_shape)
+            return empty
+
+        return super().forward(x)
+
+
+class Linear(torch.nn.Linear):
+
+    def forward(self, x):
+        # empty tensor forward of Linear layer is supported in PyTorch 1.6
+        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)):
+            out_shape = [x.shape[0], self.out_features]
+            empty = NewEmptyTensorOp.apply(x, out_shape)
+            if self.training:
+                # produce dummy gradient to avoid DDP warning.
+                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+                return empty + dummy
+            else:
+                return empty
+
+        return super().forward(x)
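+
+
+# ---------------------------------------------------------------------------
+# Editor's sketch (not part of upstream mmcv): the wrapped Conv2d returns a
+# correctly shaped empty tensor when fed a batch with zero RoIs (older
+# PyTorch handles this via NewEmptyTensorOp, newer PyTorch natively).
+def _example_empty_tensor_forward():
+    conv = Conv2d(3, 8, kernel_size=3, padding=1)
+    empty_batch = torch.zeros(0, 3, 32, 32)  # zero RoIs, e.g. no positives
+    out = conv(empty_batch)
+    return out.shape  # torch.Size([0, 8, 32, 32])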
diff --git a/annotator/uniformer/mmcv/cnn/builder.py b/annotator/uniformer/mmcv/cnn/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..7567316c566bd3aca6d8f65a84b00e9e890948a7
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/builder.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..runner import Sequential
+from ..utils import Registry, build_from_cfg
+
+
+def build_model_from_cfg(cfg, registry, default_args=None):
+    """Build a PyTorch model from config dict(s). Different from
+    ``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.
+
+    Args:
+        cfg (dict, list[dict]): The config of modules, which is either a config
+            dict or a list of config dicts. If cfg is a list,
+            the built modules will be wrapped with ``nn.Sequential``.
+        registry (:obj:`Registry`): A registry the module belongs to.
+        default_args (dict, optional): Default arguments to build the module.
+            Defaults to None.
+
+    Returns:
+        nn.Module: A built nn module.
+    """
+    if isinstance(cfg, list):
+        modules = [
+            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
+        ]
+        return Sequential(*modules)
+    else:
+        return build_from_cfg(cfg, registry, default_args)
+
+
+MODELS = Registry('model', build_func=build_model_from_cfg)
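+
+
+# ---------------------------------------------------------------------------
+# Editor's sketch (not part of upstream mmcv): registering a toy module in
+# MODELS and building it from a config dict. `ToyModel` is an invented name
+# used purely for illustration.
+def _example_build_model_from_cfg():
+    import torch.nn as nn  # local import; this module itself does not need nn
+
+    @MODELS.register_module()
+    class ToyModel(nn.Module):
+
+        def __init__(self, in_features=4, out_features=2):
+            super().__init__()
+            self.fc = nn.Linear(in_features, out_features)
+
+        def forward(self, x):
+            return self.fc(x)
+
+    # A single dict builds one module; a list of dicts would be wrapped
+    # in ``nn.Sequential`` by ``build_model_from_cfg``.
+    return build_model_from_cfg(dict(type='ToyModel', in_features=8), MODELS)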
diff --git a/annotator/uniformer/mmcv/cnn/resnet.py b/annotator/uniformer/mmcv/cnn/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cb3ac057ee2d52c46fc94685b5d4e698aad8d5f
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/resnet.py
@@ -0,0 +1,316 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+
+from .utils import constant_init, kaiming_init
+
+
+def conv3x3(in_planes, out_planes, stride=1, dilation=1):
+    """3x3 convolution with padding."""
+    return nn.Conv2d(
+        in_planes,
+        out_planes,
+        kernel_size=3,
+        stride=stride,
+        padding=dilation,
+        dilation=dilation,
+        bias=False)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self,
+                 inplanes,
+                 planes,
+                 stride=1,
+                 dilation=1,
+                 downsample=None,
+                 style='pytorch',
+                 with_cp=False):
+        super(BasicBlock, self).__init__()
+        assert style in ['pytorch', 'caffe']
+        self.conv1 = conv3x3(inplanes, planes, stride, dilation)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+        self.dilation = dilation
+        assert not with_cp
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self,
+                 inplanes,
+                 planes,
+                 stride=1,
+                 dilation=1,
+                 downsample=None,
+                 style='pytorch',
+                 with_cp=False):
+        """Bottleneck block.
+
+        If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
+        it is "caffe", the stride-two layer is the first 1x1 conv layer.
+        """
+        super(Bottleneck, self).__init__()
+        assert style in ['pytorch', 'caffe']
+        if style == 'pytorch':
+            conv1_stride = 1
+            conv2_stride = stride
+        else:
+            conv1_stride = stride
+            conv2_stride = 1
+        self.conv1 = nn.Conv2d(
+            inplanes, planes, kernel_size=1, stride=conv1_stride, bias=False)
+        self.conv2 = nn.Conv2d(
+            planes,
+            planes,
+            kernel_size=3,
+            stride=conv2_stride,
+            padding=dilation,
+            dilation=dilation,
+            bias=False)
+
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(
+            planes, planes * self.expansion, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+        self.dilation = dilation
+        self.with_cp = with_cp
+
+    def forward(self, x):
+
+        def _inner_forward(x):
+            residual = x
+
+            out = self.conv1(x)
+            out = self.bn1(out)
+            out = self.relu(out)
+
+            out = self.conv2(out)
+            out = self.bn2(out)
+            out = self.relu(out)
+
+            out = self.conv3(out)
+            out = self.bn3(out)
+
+            if self.downsample is not None:
+                residual = self.downsample(x)
+
+            out += residual
+
+            return out
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(_inner_forward, x)
+        else:
+            out = _inner_forward(x)
+
+        out = self.relu(out)
+
+        return out
+
+
+def make_res_layer(block,
+                   inplanes,
+                   planes,
+                   blocks,
+                   stride=1,
+                   dilation=1,
+                   style='pytorch',
+                   with_cp=False):
+    downsample = None
+    if stride != 1 or inplanes != planes * block.expansion:
+        downsample = nn.Sequential(
+            nn.Conv2d(
+                inplanes,
+                planes * block.expansion,
+                kernel_size=1,
+                stride=stride,
+                bias=False),
+            nn.BatchNorm2d(planes * block.expansion),
+        )
+
+    layers = []
+    layers.append(
+        block(
+            inplanes,
+            planes,
+            stride,
+            dilation,
+            downsample,
+            style=style,
+            with_cp=with_cp))
+    inplanes = planes * block.expansion
+    for _ in range(1, blocks):
+        layers.append(
+            block(inplanes, planes, 1, dilation, style=style, with_cp=with_cp))
+
+    return nn.Sequential(*layers)
+
+
+class ResNet(nn.Module):
+    """ResNet backbone.
+
+    Args:
+        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
+        num_stages (int): Resnet stages, normally 4.
+        strides (Sequence[int]): Strides of the first block of each stage.
+        dilations (Sequence[int]): Dilation of each stage.
+        out_indices (Sequence[int]): Output from which stages.
+        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+            layer is the 3x3 conv layer, otherwise the stride-two layer is
+            the first 1x1 conv layer.
+        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+            not freezing any parameters.
+        bn_eval (bool): Whether to set BN layers to eval mode, namely, freeze
+            running stats (mean and var).
+        bn_frozen (bool): Whether to freeze weight and bias of BN layers.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed.
+    """
+
+    arch_settings = {
+        18: (BasicBlock, (2, 2, 2, 2)),
+        34: (BasicBlock, (3, 4, 6, 3)),
+        50: (Bottleneck, (3, 4, 6, 3)),
+        101: (Bottleneck, (3, 4, 23, 3)),
+        152: (Bottleneck, (3, 8, 36, 3))
+    }
+
+    def __init__(self,
+                 depth,
+                 num_stages=4,
+                 strides=(1, 2, 2, 2),
+                 dilations=(1, 1, 1, 1),
+                 out_indices=(0, 1, 2, 3),
+                 style='pytorch',
+                 frozen_stages=-1,
+                 bn_eval=True,
+                 bn_frozen=False,
+                 with_cp=False):
+        super(ResNet, self).__init__()
+        if depth not in self.arch_settings:
+            raise KeyError(f'invalid depth {depth} for resnet')
+        assert num_stages >= 1 and num_stages <= 4
+        block, stage_blocks = self.arch_settings[depth]
+        stage_blocks = stage_blocks[:num_stages]
+        assert len(strides) == len(dilations) == num_stages
+        assert max(out_indices) < num_stages
+
+        self.out_indices = out_indices
+        self.style = style
+        self.frozen_stages = frozen_stages
+        self.bn_eval = bn_eval
+        self.bn_frozen = bn_frozen
+        self.with_cp = with_cp
+
+        self.inplanes = 64
+        self.conv1 = nn.Conv2d(
+            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+        self.res_layers = []
+        for i, num_blocks in enumerate(stage_blocks):
+            stride = strides[i]
+            dilation = dilations[i]
+            planes = 64 * 2**i
+            res_layer = make_res_layer(
+                block,
+                self.inplanes,
+                planes,
+                num_blocks,
+                stride=stride,
+                dilation=dilation,
+                style=self.style,
+                with_cp=with_cp)
+            self.inplanes = planes * block.expansion
+            layer_name = f'layer{i + 1}'
+            self.add_module(layer_name, res_layer)
+            self.res_layers.append(layer_name)
+
+        self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1)
+
+    def init_weights(self, pretrained=None):
+        if isinstance(pretrained, str):
+            logger = logging.getLogger()
+            from ..runner import load_checkpoint
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    kaiming_init(m)
+                elif isinstance(m, nn.BatchNorm2d):
+                    constant_init(m, 1)
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+        outs = []
+        for i, layer_name in enumerate(self.res_layers):
+            res_layer = getattr(self, layer_name)
+            x = res_layer(x)
+            if i in self.out_indices:
+                outs.append(x)
+        if len(outs) == 1:
+            return outs[0]
+        else:
+            return tuple(outs)
+
+    def train(self, mode=True):
+        super(ResNet, self).train(mode)
+        if self.bn_eval:
+            for m in self.modules():
+                if isinstance(m, nn.BatchNorm2d):
+                    m.eval()
+                    if self.bn_frozen:
+                        for params in m.parameters():
+                            params.requires_grad = False
+        if mode and self.frozen_stages >= 0:
+            for param in self.conv1.parameters():
+                param.requires_grad = False
+            for param in self.bn1.parameters():
+                param.requires_grad = False
+            self.bn1.eval()
+            self.bn1.weight.requires_grad = False
+            self.bn1.bias.requires_grad = False
+            for i in range(1, self.frozen_stages + 1):
+                mod = getattr(self, f'layer{i}')
+                mod.eval()
+                for param in mod.parameters():
+                    param.requires_grad = False
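+
+
+# ---------------------------------------------------------------------------
+# Editor's sketch (not part of upstream mmcv): a ResNet-50 backbone returning
+# feature maps from all four stages. The input size is an arbitrary example.
+def _example_resnet_backbone():
+    import torch  # local import; this module itself only needs torch.nn
+    backbone = ResNet(depth=50, out_indices=(0, 1, 2, 3))
+    backbone.init_weights()  # random init; pass a checkpoint path to load weights
+    backbone.eval()
+    with torch.no_grad():
+        feats = backbone(torch.rand(1, 3, 224, 224))
+    # Four tensors with strides 4, 8, 16 and 32 relative to the input.
+    return [f.shape for f in feats]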
diff --git a/annotator/uniformer/mmcv/cnn/utils/__init__.py b/annotator/uniformer/mmcv/cnn/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a263e31c1e3977712827ca229bbc04910b4e928e
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/utils/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .flops_counter import get_model_complexity_info
+from .fuse_conv_bn import fuse_conv_bn
+from .sync_bn import revert_sync_batchnorm
+from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
+                          KaimingInit, NormalInit, PretrainedInit,
+                          TruncNormalInit, UniformInit, XavierInit,
+                          bias_init_with_prob, caffe2_xavier_init,
+                          constant_init, initialize, kaiming_init, normal_init,
+                          trunc_normal_init, uniform_init, xavier_init)
+
+__all__ = [
+    'get_model_complexity_info', 'bias_init_with_prob', 'caffe2_xavier_init',
+    'constant_init', 'kaiming_init', 'normal_init', 'trunc_normal_init',
+    'uniform_init', 'xavier_init', 'fuse_conv_bn', 'initialize',
+    'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
+    'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
+    'Caffe2XavierInit', 'revert_sync_batchnorm'
+]
diff --git a/annotator/uniformer/mmcv/cnn/utils/flops_counter.py b/annotator/uniformer/mmcv/cnn/utils/flops_counter.py
new file mode 100644
index 0000000000000000000000000000000000000000..d10af5feca7f4b8c0ba359b7b1c826f754e048be
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/utils/flops_counter.py
@@ -0,0 +1,599 @@
+# Modified from flops-counter.pytorch by Vladislav Sovrasov
+# original repo: https://github.com/sovrasov/flops-counter.pytorch
+
+# MIT License
+
+# Copyright (c) 2018 Vladislav Sovrasov
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import sys
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+import annotator.uniformer.mmcv as mmcv
+
+
+def get_model_complexity_info(model,
+                              input_shape,
+                              print_per_layer_stat=True,
+                              as_strings=True,
+                              input_constructor=None,
+                              flush=False,
+                              ost=sys.stdout):
+    """Get complexity information of a model.
+
+    This method can calculate FLOPs and parameter counts of a model with
+    corresponding input shape. It can also print complexity information for
+    each layer in a model.
+
+    Supported layers are listed below:
+        - Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``.
+        - Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``, ``nn.LeakyReLU``,
+            ``nn.ReLU6``.
+        - Poolings: ``nn.MaxPool1d``, ``nn.MaxPool2d``, ``nn.MaxPool3d``,
+            ``nn.AvgPool1d``, ``nn.AvgPool2d``, ``nn.AvgPool3d``,
+            ``nn.AdaptiveMaxPool1d``, ``nn.AdaptiveMaxPool2d``,
+            ``nn.AdaptiveMaxPool3d``, ``nn.AdaptiveAvgPool1d``,
+            ``nn.AdaptiveAvgPool2d``, ``nn.AdaptiveAvgPool3d``.
+        - BatchNorms: ``nn.BatchNorm1d``, ``nn.BatchNorm2d``,
+            ``nn.BatchNorm3d``, ``nn.GroupNorm``, ``nn.InstanceNorm1d``,
+            ``InstanceNorm2d``, ``InstanceNorm3d``, ``nn.LayerNorm``.
+        - Linear: ``nn.Linear``.
+        - Deconvolution: ``nn.ConvTranspose2d``.
+        - Upsample: ``nn.Upsample``.
+
+    Args:
+        model (nn.Module): The model for complexity calculation.
+        input_shape (tuple): Input shape used for calculation.
+        print_per_layer_stat (bool): Whether to print complexity information
+            for each layer in a model. Default: True.
+        as_strings (bool): Output FLOPs and params counts in a string form.
+            Default: True.
+        input_constructor (None | callable): If specified, it is a callable
+            that generates the input. Otherwise, a random tensor with the
+            given input shape will be used to calculate FLOPs. Default: None.
+        flush (bool): same as that in :func:`print`. Default: False.
+        ost (stream): same as ``file`` param in :func:`print`.
+            Default: sys.stdout.
+
+    Returns:
+        tuple[float | str]: If ``as_strings`` is set to True, it will return
+            FLOPs and parameter counts in a string format. Otherwise, it will
+            return them as float numbers.
+    """
+    assert type(input_shape) is tuple
+    assert len(input_shape) >= 1
+    assert isinstance(model, nn.Module)
+    flops_model = add_flops_counting_methods(model)
+    flops_model.eval()
+    flops_model.start_flops_count()
+    if input_constructor:
+        input = input_constructor(input_shape)
+        _ = flops_model(**input)
+    else:
+        try:
+            batch = torch.ones(()).new_empty(
+                (1, *input_shape),
+                dtype=next(flops_model.parameters()).dtype,
+                device=next(flops_model.parameters()).device)
+        except StopIteration:
+            # Avoid StopIteration for models which have no parameters,
+            # like `nn.Relu()`, `nn.AvgPool2d`, etc.
+            batch = torch.ones(()).new_empty((1, *input_shape))
+
+        _ = flops_model(batch)
+
+    flops_count, params_count = flops_model.compute_average_flops_cost()
+    if print_per_layer_stat:
+        print_model_with_flops(
+            flops_model, flops_count, params_count, ost=ost, flush=flush)
+    flops_model.stop_flops_count()
+
+    if as_strings:
+        return flops_to_string(flops_count), params_to_string(params_count)
+
+    return flops_count, params_count
+
+
+def flops_to_string(flops, units='GFLOPs', precision=2):
+    """Convert FLOPs number into a string.
+
+    Note that here we count one multiply-add as one FLOP.
+
+    Args:
+        flops (float): FLOPs number to be converted.
+        units (str | None): Converted FLOPs units. Options are None, 'GFLOPs',
+            'MFLOPs', 'KFLOPs', 'FLOPs'. If set to None, it will automatically
+            choose the most suitable unit for FLOPs. Default: 'GFLOPs'.
+        precision (int): Digit number after the decimal point. Default: 2.
+
+    Returns:
+        str: The converted FLOPs number with units.
+
+    Examples:
+        >>> flops_to_string(1e9)
+        '1.0 GFLOPs'
+        >>> flops_to_string(2e5, 'MFLOPs')
+        '0.2 MFLOPs'
+        >>> flops_to_string(3e-9, None)
+        '3e-09 FLOPs'
+    """
+    if units is None:
+        if flops // 10**9 > 0:
+            return str(round(flops / 10.**9, precision)) + ' GFLOPs'
+        elif flops // 10**6 > 0:
+            return str(round(flops / 10.**6, precision)) + ' MFLOPs'
+        elif flops // 10**3 > 0:
+            return str(round(flops / 10.**3, precision)) + ' KFLOPs'
+        else:
+            return str(flops) + ' FLOPs'
+    else:
+        if units == 'GFLOPs':
+            return str(round(flops / 10.**9, precision)) + ' ' + units
+        elif units == 'MFLOPs':
+            return str(round(flops / 10.**6, precision)) + ' ' + units
+        elif units == 'KFLOPs':
+            return str(round(flops / 10.**3, precision)) + ' ' + units
+        else:
+            return str(flops) + ' FLOPs'
+
+
+def params_to_string(num_params, units=None, precision=2):
+    """Convert parameter number into a string.
+
+    Args:
+        num_params (float): Parameter number to be converted.
+        units (str | None): Converted FLOPs units. Options are None, 'M',
+            'K' and ''. If set to None, it will automatically choose the most
+            suitable unit for Parameter number. Default: None.
+        precision (int): Digit number after the decimal point. Default: 2.
+
+    Returns:
+        str: The converted parameter number with units.
+
+    Examples:
+        >>> params_to_string(1e9)
+        '1000.0 M'
+        >>> params_to_string(2e5)
+        '200.0 k'
+        >>> params_to_string(3e-9)
+        '3e-09'
+    """
+    if units is None:
+        if num_params // 10**6 > 0:
+            return str(round(num_params / 10**6, precision)) + ' M'
+        elif num_params // 10**3:
+            return str(round(num_params / 10**3, precision)) + ' k'
+        else:
+            return str(num_params)
+    else:
+        if units == 'M':
+            return str(round(num_params / 10.**6, precision)) + ' ' + units
+        elif units == 'K':
+            return str(round(num_params / 10.**3, precision)) + ' ' + units
+        else:
+            return str(num_params)
+
+
+def print_model_with_flops(model,
+                           total_flops,
+                           total_params,
+                           units='GFLOPs',
+                           precision=3,
+                           ost=sys.stdout,
+                           flush=False):
+    """Print a model with FLOPs for each layer.
+
+    Args:
+        model (nn.Module): The model to be printed.
+        total_flops (float): Total FLOPs of the model.
+        total_params (float): Total parameter counts of the model.
+        units (str | None): Converted FLOPs units. Default: 'GFLOPs'.
+        precision (int): Digit number after the decimal point. Default: 3.
+        ost (stream): same as `file` param in :func:`print`.
+            Default: sys.stdout.
+        flush (bool): same as that in :func:`print`. Default: False.
+
+    Example:
+        >>> class ExampleModel(nn.Module):
+
+        >>> def __init__(self):
+        >>>     super().__init__()
+        >>>     self.conv1 = nn.Conv2d(3, 8, 3)
+        >>>     self.conv2 = nn.Conv2d(8, 256, 3)
+        >>>     self.conv3 = nn.Conv2d(256, 8, 3)
+        >>>     self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
+        >>>     self.flatten = nn.Flatten()
+        >>>     self.fc = nn.Linear(8, 1)
+
+        >>> def forward(self, x):
+        >>>     x = self.conv1(x)
+        >>>     x = self.conv2(x)
+        >>>     x = self.conv3(x)
+        >>>     x = self.avg_pool(x)
+        >>>     x = self.flatten(x)
+        >>>     x = self.fc(x)
+        >>>     return x
+
+        >>> model = ExampleModel()
+        >>> x = (3, 16, 16)
+        To print the complexity information for each layer, you can use
+        >>> get_model_complexity_info(model, x)
+        or directly use
+        >>> print_model_with_flops(model, 4579784.0, 37361)
+        ExampleModel(
+          0.037 M, 100.000% Params, 0.005 GFLOPs, 100.000% FLOPs,
+          (conv1): Conv2d(0.0 M, 0.600% Params, 0.0 GFLOPs, 0.959% FLOPs, 3, 8, kernel_size=(3, 3), stride=(1, 1))  # noqa: E501
+          (conv2): Conv2d(0.019 M, 50.020% Params, 0.003 GFLOPs, 58.760% FLOPs, 8, 256, kernel_size=(3, 3), stride=(1, 1))
+          (conv3): Conv2d(0.018 M, 49.356% Params, 0.002 GFLOPs, 40.264% FLOPs, 256, 8, kernel_size=(3, 3), stride=(1, 1))
+          (avg_pool): AdaptiveAvgPool2d(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.017% FLOPs, output_size=(1, 1))
+          (flatten): Flatten(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.000% FLOPs, )
+          (fc): Linear(0.0 M, 0.024% Params, 0.0 GFLOPs, 0.000% FLOPs, in_features=8, out_features=1, bias=True)
+        )
+    """
+
+    def accumulate_params(self):
+        if is_supported_instance(self):
+            return self.__params__
+        else:
+            sum = 0
+            for m in self.children():
+                sum += m.accumulate_params()
+            return sum
+
+    def accumulate_flops(self):
+        if is_supported_instance(self):
+            return self.__flops__ / model.__batch_counter__
+        else:
+            sum = 0
+            for m in self.children():
+                sum += m.accumulate_flops()
+            return sum
+
+    def flops_repr(self):
+        accumulated_num_params = self.accumulate_params()
+        accumulated_flops_cost = self.accumulate_flops()
+        return ', '.join([
+            params_to_string(
+                accumulated_num_params, units='M', precision=precision),
+            '{:.3%} Params'.format(accumulated_num_params / total_params),
+            flops_to_string(
+                accumulated_flops_cost, units=units, precision=precision),
+            '{:.3%} FLOPs'.format(accumulated_flops_cost / total_flops),
+            self.original_extra_repr()
+        ])
+
+    def add_extra_repr(m):
+        m.accumulate_flops = accumulate_flops.__get__(m)
+        m.accumulate_params = accumulate_params.__get__(m)
+        flops_extra_repr = flops_repr.__get__(m)
+        if m.extra_repr != flops_extra_repr:
+            m.original_extra_repr = m.extra_repr
+            m.extra_repr = flops_extra_repr
+            assert m.extra_repr != m.original_extra_repr
+
+    def del_extra_repr(m):
+        if hasattr(m, 'original_extra_repr'):
+            m.extra_repr = m.original_extra_repr
+            del m.original_extra_repr
+        if hasattr(m, 'accumulate_flops'):
+            del m.accumulate_flops
+
+    model.apply(add_extra_repr)
+    print(model, file=ost, flush=flush)
+    model.apply(del_extra_repr)
+
+
+def get_model_parameters_number(model):
+    """Calculate parameter number of a model.
+
+    Args:
+        model (nn.module): The model for parameter number calculation.
+
+    Returns:
+        float: Parameter number of the model.
+    """
+    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    return num_params
+
+
+def add_flops_counting_methods(net_main_module):
+    # adding additional methods to the existing module object,
+    # this is done this way so that each function has access to self object
+    net_main_module.start_flops_count = start_flops_count.__get__(
+        net_main_module)
+    net_main_module.stop_flops_count = stop_flops_count.__get__(
+        net_main_module)
+    net_main_module.reset_flops_count = reset_flops_count.__get__(
+        net_main_module)
+    net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(  # noqa: E501
+        net_main_module)
+
+    net_main_module.reset_flops_count()
+
+    return net_main_module
+
+
+def compute_average_flops_cost(self):
+    """Compute average FLOPs cost.
+
+    A method to compute average FLOPs cost, which will be available after
+    `add_flops_counting_methods()` is called on a desired net object.
+
+    Returns:
+        float: Current mean flops consumption per image.
+    """
+    batches_count = self.__batch_counter__
+    flops_sum = 0
+    for module in self.modules():
+        if is_supported_instance(module):
+            flops_sum += module.__flops__
+    params_sum = get_model_parameters_number(self)
+    return flops_sum / batches_count, params_sum
+
+
+def start_flops_count(self):
+    """Activate the computation of mean flops consumption per image.
+
+    A method to activate the computation of mean flops consumption per image.
+    which will be available after ``add_flops_counting_methods()`` is called on
+    a desired net object. It should be called before running the network.
+    """
+    add_batch_counter_hook_function(self)
+
+    def add_flops_counter_hook_function(module):
+        if is_supported_instance(module):
+            if hasattr(module, '__flops_handle__'):
+                return
+
+            else:
+                handle = module.register_forward_hook(
+                    get_modules_mapping()[type(module)])
+
+            module.__flops_handle__ = handle
+
+    self.apply(partial(add_flops_counter_hook_function))
+
+
+def stop_flops_count(self):
+    """Stop computing the mean flops consumption per image.
+
+    A method to stop computing the mean flops consumption per image, which will
+    be available after ``add_flops_counting_methods()`` is called on a desired
+    net object. It can be called to pause the computation at any time.
+    """
+    remove_batch_counter_hook_function(self)
+    self.apply(remove_flops_counter_hook_function)
+
+
+def reset_flops_count(self):
+    """Reset statistics computed so far.
+
+    A method to reset the computed statistics, which will be available after
+    `add_flops_counting_methods()` is called on a desired net object.
+    """
+    add_batch_counter_variables_or_reset(self)
+    self.apply(add_flops_counter_variable_or_reset)
+
+
+# ---- Internal functions
+def empty_flops_counter_hook(module, input, output):
+    module.__flops__ += 0
+
+
+def upsample_flops_counter_hook(module, input, output):
+    output_size = output[0]
+    batch_size = output_size.shape[0]
+    output_elements_count = batch_size
+    for val in output_size.shape[1:]:
+        output_elements_count *= val
+    module.__flops__ += int(output_elements_count)
+
+
+def relu_flops_counter_hook(module, input, output):
+    active_elements_count = output.numel()
+    module.__flops__ += int(active_elements_count)
+
+
+def linear_flops_counter_hook(module, input, output):
+    input = input[0]
+    output_last_dim = output.shape[
+        -1]  # pytorch checks dimensions, so here we don't care much
+    module.__flops__ += int(np.prod(input.shape) * output_last_dim)
+
+
+def pool_flops_counter_hook(module, input, output):
+    input = input[0]
+    module.__flops__ += int(np.prod(input.shape))
+
+
+def norm_flops_counter_hook(module, input, output):
+    input = input[0]
+
+    batch_flops = np.prod(input.shape)
+    if (getattr(module, 'affine', False)
+            or getattr(module, 'elementwise_affine', False)):
+        batch_flops *= 2
+    module.__flops__ += int(batch_flops)
+
+
+def deconv_flops_counter_hook(conv_module, input, output):
+    # Can have multiple inputs, getting the first one
+    input = input[0]
+
+    batch_size = input.shape[0]
+    input_height, input_width = input.shape[2:]
+
+    kernel_height, kernel_width = conv_module.kernel_size
+    in_channels = conv_module.in_channels
+    out_channels = conv_module.out_channels
+    groups = conv_module.groups
+
+    filters_per_channel = out_channels // groups
+    conv_per_position_flops = (
+        kernel_height * kernel_width * in_channels * filters_per_channel)
+
+    active_elements_count = batch_size * input_height * input_width
+    overall_conv_flops = conv_per_position_flops * active_elements_count
+    bias_flops = 0
+    if conv_module.bias is not None:
+        output_height, output_width = output.shape[2:]
+        bias_flops = out_channels * batch_size * output_height * output_width
+    overall_flops = overall_conv_flops + bias_flops
+
+    conv_module.__flops__ += int(overall_flops)
+
+
+def conv_flops_counter_hook(conv_module, input, output):
+    # Can have multiple inputs, getting the first one
+    input = input[0]
+
+    batch_size = input.shape[0]
+    output_dims = list(output.shape[2:])
+
+    kernel_dims = list(conv_module.kernel_size)
+    in_channels = conv_module.in_channels
+    out_channels = conv_module.out_channels
+    groups = conv_module.groups
+
+    filters_per_channel = out_channels // groups
+    conv_per_position_flops = int(
+        np.prod(kernel_dims)) * in_channels * filters_per_channel
+
+    active_elements_count = batch_size * int(np.prod(output_dims))
+
+    overall_conv_flops = conv_per_position_flops * active_elements_count
+
+    bias_flops = 0
+
+    if conv_module.bias is not None:
+
+        bias_flops = out_channels * active_elements_count
+
+    overall_flops = overall_conv_flops + bias_flops
+
+    conv_module.__flops__ += int(overall_flops)
+
+
+def batch_counter_hook(module, input, output):
+    batch_size = 1
+    if len(input) > 0:
+        # Can have multiple inputs, getting the first one
+        input = input[0]
+        batch_size = len(input)
+    else:
+        print('Warning! No positional inputs found for a module, '
+              'assuming batch size is 1.')
+    module.__batch_counter__ += batch_size
+
+
+def add_batch_counter_variables_or_reset(module):
+
+    module.__batch_counter__ = 0
+
+
+def add_batch_counter_hook_function(module):
+    if hasattr(module, '__batch_counter_handle__'):
+        return
+
+    handle = module.register_forward_hook(batch_counter_hook)
+    module.__batch_counter_handle__ = handle
+
+
+def remove_batch_counter_hook_function(module):
+    if hasattr(module, '__batch_counter_handle__'):
+        module.__batch_counter_handle__.remove()
+        del module.__batch_counter_handle__
+
+
+def add_flops_counter_variable_or_reset(module):
+    if is_supported_instance(module):
+        if hasattr(module, '__flops__') or hasattr(module, '__params__'):
+            print('Warning: variables __flops__ or __params__ are already '
+                  'defined for the module ' + type(module).__name__ +
+                  '. ptflops can affect your code!')
+        module.__flops__ = 0
+        module.__params__ = get_model_parameters_number(module)
+
+
+def is_supported_instance(module):
+    if type(module) in get_modules_mapping():
+        return True
+    return False
+
+
+def remove_flops_counter_hook_function(module):
+    if is_supported_instance(module):
+        if hasattr(module, '__flops_handle__'):
+            module.__flops_handle__.remove()
+            del module.__flops_handle__
+
+
+def get_modules_mapping():
+    return {
+        # convolutions
+        nn.Conv1d: conv_flops_counter_hook,
+        nn.Conv2d: conv_flops_counter_hook,
+        mmcv.cnn.bricks.Conv2d: conv_flops_counter_hook,
+        nn.Conv3d: conv_flops_counter_hook,
+        mmcv.cnn.bricks.Conv3d: conv_flops_counter_hook,
+        # activations
+        nn.ReLU: relu_flops_counter_hook,
+        nn.PReLU: relu_flops_counter_hook,
+        nn.ELU: relu_flops_counter_hook,
+        nn.LeakyReLU: relu_flops_counter_hook,
+        nn.ReLU6: relu_flops_counter_hook,
+        # poolings
+        nn.MaxPool1d: pool_flops_counter_hook,
+        nn.AvgPool1d: pool_flops_counter_hook,
+        nn.AvgPool2d: pool_flops_counter_hook,
+        nn.MaxPool2d: pool_flops_counter_hook,
+        mmcv.cnn.bricks.MaxPool2d: pool_flops_counter_hook,
+        nn.MaxPool3d: pool_flops_counter_hook,
+        mmcv.cnn.bricks.MaxPool3d: pool_flops_counter_hook,
+        nn.AvgPool3d: pool_flops_counter_hook,
+        nn.AdaptiveMaxPool1d: pool_flops_counter_hook,
+        nn.AdaptiveAvgPool1d: pool_flops_counter_hook,
+        nn.AdaptiveMaxPool2d: pool_flops_counter_hook,
+        nn.AdaptiveAvgPool2d: pool_flops_counter_hook,
+        nn.AdaptiveMaxPool3d: pool_flops_counter_hook,
+        nn.AdaptiveAvgPool3d: pool_flops_counter_hook,
+        # normalizations
+        nn.BatchNorm1d: norm_flops_counter_hook,
+        nn.BatchNorm2d: norm_flops_counter_hook,
+        nn.BatchNorm3d: norm_flops_counter_hook,
+        nn.GroupNorm: norm_flops_counter_hook,
+        nn.InstanceNorm1d: norm_flops_counter_hook,
+        nn.InstanceNorm2d: norm_flops_counter_hook,
+        nn.InstanceNorm3d: norm_flops_counter_hook,
+        nn.LayerNorm: norm_flops_counter_hook,
+        # FC
+        nn.Linear: linear_flops_counter_hook,
+        mmcv.cnn.bricks.Linear: linear_flops_counter_hook,
+        # Upscale
+        nn.Upsample: upsample_flops_counter_hook,
+        # Deconvolution
+        nn.ConvTranspose2d: deconv_flops_counter_hook,
+        mmcv.cnn.bricks.ConvTranspose2d: deconv_flops_counter_hook,
+    }
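+
+
+# ---------------------------------------------------------------------------
+# Editor's sketch (not part of upstream mmcv): computing the complexity of a
+# tiny conv net. Layer sizes and the input shape are arbitrary illustration
+# values; by default the result comes back as human-readable strings.
+def _example_get_model_complexity_info():
+    model = nn.Sequential(
+        nn.Conv2d(3, 16, kernel_size=3, padding=1),
+        nn.ReLU(),
+        nn.AdaptiveAvgPool2d(1),
+        nn.Flatten(),
+        nn.Linear(16, 10))
+    flops, params = get_model_complexity_info(
+        model, (3, 224, 224), print_per_layer_stat=False)
+    return flops, params  # e.g. '0.02 GFLOPs' and a parameter-count string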
diff --git a/annotator/uniformer/mmcv/cnn/utils/fuse_conv_bn.py b/annotator/uniformer/mmcv/cnn/utils/fuse_conv_bn.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb7076f80bf37f7931185bf0293ffcc1ce19c8ef
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/utils/fuse_conv_bn.py
@@ -0,0 +1,59 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+
+def _fuse_conv_bn(conv, bn):
+    """Fuse conv and bn into one module.
+
+    Args:
+        conv (nn.Module): Conv to be fused.
+        bn (nn.Module): BN to be fused.
+
+    Returns:
+        nn.Module: Fused module.
+    """
+    conv_w = conv.weight
+    conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
+        bn.running_mean)
+
+    factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
+    conv.weight = nn.Parameter(conv_w *
+                               factor.reshape([conv.out_channels, 1, 1, 1]))
+    conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
+    return conv
+
+
+def fuse_conv_bn(module):
+    """Recursively fuse conv and bn in a module.
+
+    During inference, the functionality of batch norm layers is turned off
+    and only the per-channel mean and var are used, which makes it possible
+    to fuse them into the preceding conv layers to save computation and
+    simplify the network structure.
+
+    Args:
+        module (nn.Module): Module to be fused.
+
+    Returns:
+        nn.Module: Fused module.
+    """
+    last_conv = None
+    last_conv_name = None
+
+    for name, child in module.named_children():
+        if isinstance(child,
+                      (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
+            if last_conv is None:  # only fuse BN that is after Conv
+                continue
+            fused_conv = _fuse_conv_bn(last_conv, child)
+            module._modules[last_conv_name] = fused_conv
+            # To reduce changes, set BN as Identity instead of deleting it.
+            module._modules[name] = nn.Identity()
+            last_conv = None
+        elif isinstance(child, nn.Conv2d):
+            last_conv = child
+            last_conv_name = name
+        else:
+            fuse_conv_bn(child)
+    return module
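+
+
+if __name__ == '__main__':
+    # A minimal sketch (illustrative addition, not upstream mmcv code): fuse
+    # a toy Conv2d + BatchNorm2d pair and check that eval-mode outputs match.
+    toy = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8))
+    toy.eval()
+    x = torch.randn(1, 3, 16, 16)
+    with torch.no_grad():
+        before = toy(x)
+        after = fuse_conv_bn(toy)(x)
+    # After fusion the BN layer becomes nn.Identity, so the two outputs
+    # should agree up to floating-point tolerance.
+    assert torch.allclose(before, after, atol=1e-4)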
diff --git a/annotator/uniformer/mmcv/cnn/utils/sync_bn.py b/annotator/uniformer/mmcv/cnn/utils/sync_bn.py
new file mode 100644
index 0000000000000000000000000000000000000000..f78f39181d75bb85c53e8c7c8eaf45690e9f0bee
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/utils/sync_bn.py
@@ -0,0 +1,59 @@
+import torch
+
+import annotator.uniformer.mmcv as mmcv
+
+
+class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
+    """A general BatchNorm layer without input dimension check.
+
+    Reproduced from @kapily's work:
+    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+    The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc.
+    is `_check_input_dim`, which is designed for tensor sanity checks.
+    The check is bypassed in this class for the convenience of converting
+    SyncBatchNorm.
+    """
+
+    def _check_input_dim(self, input):
+        return
+
+
+def revert_sync_batchnorm(module):
+    """Helper function to convert all `SyncBatchNorm` (SyncBN) and
+    `mmcv.ops.sync_bn.SyncBatchNorm` (MMSyncBN) layers in the model to
+    `BatchNormXd` layers.
+
+    Adapted from @kapily's work:
+    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
+
+    Args:
+        module (nn.Module): The module containing `SyncBatchNorm` layers.
+
+    Returns:
+        module_output: The converted module with `BatchNormXd` layers.
+    """
+    module_output = module
+    module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]
+    if hasattr(mmcv, 'ops'):
+        module_checklist.append(mmcv.ops.SyncBatchNorm)
+    if isinstance(module, tuple(module_checklist)):
+        module_output = _BatchNormXd(module.num_features, module.eps,
+                                     module.momentum, module.affine,
+                                     module.track_running_stats)
+        if module.affine:
+            # no_grad() may not be needed here but
+            # just to be consistent with `convert_sync_batchnorm()`
+            with torch.no_grad():
+                module_output.weight = module.weight
+                module_output.bias = module.bias
+        module_output.running_mean = module.running_mean
+        module_output.running_var = module.running_var
+        module_output.num_batches_tracked = module.num_batches_tracked
+        module_output.training = module.training
+        # qconfig exists in quantized models
+        if hasattr(module, 'qconfig'):
+            module_output.qconfig = module.qconfig
+    for name, child in module.named_children():
+        module_output.add_module(name, revert_sync_batchnorm(child))
+    del module
+    return module_output
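+
+
+if __name__ == '__main__':
+    # A minimal sketch (illustrative addition, not upstream mmcv code):
+    # convert a model containing SyncBatchNorm back to plain batch norm so it
+    # can run on a single device without an initialized process group.
+    import torch.nn as nn
+    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.SyncBatchNorm(8))
+    model = revert_sync_batchnorm(model)
+    print(model)  # the SyncBatchNorm child is now a _BatchNormXd
+    out = model(torch.randn(2, 3, 16, 16))
+    print(out.shape)  # torch.Size([2, 8, 14, 14])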
diff --git a/annotator/uniformer/mmcv/cnn/utils/weight_init.py b/annotator/uniformer/mmcv/cnn/utils/weight_init.py
new file mode 100644
index 0000000000000000000000000000000000000000..287a1d0bffe26e023029d48634d9b761deda7ba4
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/utils/weight_init.py
@@ -0,0 +1,684 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import math
+import warnings
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from annotator.uniformer.mmcv.utils import Registry, build_from_cfg, get_logger, print_log
+
+INITIALIZERS = Registry('initializer')
+
+
+def update_init_info(module, init_info):
+    """Update the `_params_init_info` in the module if the value of parameters
+    are changed.
+
+    Args:
+        module (obj:`nn.Module`): The module of PyTorch with a user-defined
+            attribute `_params_init_info` which records the initialization
+            information.
+        init_info (str): The string that describes the initialization.
+    """
+    assert hasattr(
+        module,
+        '_params_init_info'), f'Can not find `_params_init_info` in {module}'
+    for name, param in module.named_parameters():
+
+        assert param in module._params_init_info, (
+            f'Found a new :obj:`Parameter` '
+            f'named `{name}` while executing the '
+            f'`init_weights` of '
+            f'`{module.__class__.__name__}`. '
+            f'Please do not add or '
+            f'replace parameters while executing '
+            f'`init_weights`.')
+
+        # The parameter has been changed while executing the
+        # `init_weights` of the module.
+        mean_value = param.data.mean()
+        if module._params_init_info[param]['tmp_mean_value'] != mean_value:
+            module._params_init_info[param]['init_info'] = init_info
+            module._params_init_info[param]['tmp_mean_value'] = mean_value
+
+
+def constant_init(module, val, bias=0):
+    if hasattr(module, 'weight') and module.weight is not None:
+        nn.init.constant_(module.weight, val)
+    if hasattr(module, 'bias') and module.bias is not None:
+        nn.init.constant_(module.bias, bias)
+
+
+def xavier_init(module, gain=1, bias=0, distribution='normal'):
+    assert distribution in ['uniform', 'normal']
+    if hasattr(module, 'weight') and module.weight is not None:
+        if distribution == 'uniform':
+            nn.init.xavier_uniform_(module.weight, gain=gain)
+        else:
+            nn.init.xavier_normal_(module.weight, gain=gain)
+    if hasattr(module, 'bias') and module.bias is not None:
+        nn.init.constant_(module.bias, bias)
+
+
+def normal_init(module, mean=0, std=1, bias=0):
+    if hasattr(module, 'weight') and module.weight is not None:
+        nn.init.normal_(module.weight, mean, std)
+    if hasattr(module, 'bias') and module.bias is not None:
+        nn.init.constant_(module.bias, bias)
+
+
+def trunc_normal_init(module: nn.Module,
+                      mean: float = 0,
+                      std: float = 1,
+                      a: float = -2,
+                      b: float = 2,
+                      bias: float = 0) -> None:
+    if hasattr(module, 'weight') and module.weight is not None:
+        trunc_normal_(module.weight, mean, std, a, b)  # type: ignore
+    if hasattr(module, 'bias') and module.bias is not None:
+        nn.init.constant_(module.bias, bias)  # type: ignore
+
+
+def uniform_init(module, a=0, b=1, bias=0):
+    if hasattr(module, 'weight') and module.weight is not None:
+        nn.init.uniform_(module.weight, a, b)
+    if hasattr(module, 'bias') and module.bias is not None:
+        nn.init.constant_(module.bias, bias)
+
+
+def kaiming_init(module,
+                 a=0,
+                 mode='fan_out',
+                 nonlinearity='relu',
+                 bias=0,
+                 distribution='normal'):
+    assert distribution in ['uniform', 'normal']
+    if hasattr(module, 'weight') and module.weight is not None:
+        if distribution == 'uniform':
+            nn.init.kaiming_uniform_(
+                module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
+        else:
+            nn.init.kaiming_normal_(
+                module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
+    if hasattr(module, 'bias') and module.bias is not None:
+        nn.init.constant_(module.bias, bias)
+
+
+def caffe2_xavier_init(module, bias=0):
+    # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
+    # Acknowledgment to FAIR's internal code
+    kaiming_init(
+        module,
+        a=1,
+        mode='fan_in',
+        nonlinearity='leaky_relu',
+        bias=bias,
+        distribution='uniform')
+
+
+def bias_init_with_prob(prior_prob):
+    """initialize conv/fc bias value according to a given probability value."""
+    bias_init = float(-np.log((1 - prior_prob) / prior_prob))
+    return bias_init
+
+
+def _get_bases_name(m):
+    return [b.__name__ for b in m.__class__.__bases__]
+
+
+class BaseInit(object):
+
+    def __init__(self, *, bias=0, bias_prob=None, layer=None):
+        self.wholemodule = False
+        if not isinstance(bias, (int, float)):
+            raise TypeError(f'bias must be a number, but got a {type(bias)}')
+
+        if bias_prob is not None:
+            if not isinstance(bias_prob, float):
+                raise TypeError('bias_prob type must be float, '
+                                f'but got {type(bias_prob)}')
+
+        if layer is not None:
+            if not isinstance(layer, (str, list)):
+                raise TypeError('layer must be a str or a list of str, '
+                                f'but got a {type(layer)}')
+        else:
+            layer = []
+
+        if bias_prob is not None:
+            self.bias = bias_init_with_prob(bias_prob)
+        else:
+            self.bias = bias
+        self.layer = [layer] if isinstance(layer, str) else layer
+
+    def _get_init_info(self):
+        info = f'{self.__class__.__name__}, bias={self.bias}'
+        return info
+
+
+@INITIALIZERS.register_module(name='Constant')
+class ConstantInit(BaseInit):
+    """Initialize module parameters with constant values.
+
+    Args:
+        val (int | float): the value to fill the weights in the module with
+        bias (int | float): the value to fill the bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        layer (str | list[str], optional): the layer(s) to be initialized.
+            Defaults to None.
+    """
+
+    def __init__(self, val, **kwargs):
+        super().__init__(**kwargs)
+        self.val = val
+
+    def __call__(self, module):
+
+        def init(m):
+            if self.wholemodule:
+                constant_init(m, self.val, self.bias)
+            else:
+                layername = m.__class__.__name__
+                basesname = _get_bases_name(m)
+                if len(set(self.layer) & set([layername] + basesname)):
+                    constant_init(m, self.val, self.bias)
+
+        module.apply(init)
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+
+    def _get_init_info(self):
+        info = f'{self.__class__.__name__}: val={self.val}, bias={self.bias}'
+        return info
+
+
+@INITIALIZERS.register_module(name='Xavier')
+class XavierInit(BaseInit):
+    r"""Initialize module parameters with values according to the method
+    described in `Understanding the difficulty of training deep feedforward
+    neural networks - Glorot, X. & Bengio, Y. (2010).
+    <http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_
+
+    Args:
+        gain (int | float): an optional scaling factor. Defaults to 1.
+        bias (int | float): the value to fill the bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        distribution (str): distribution either be ``'normal'``
+            or ``'uniform'``. Defaults to ``'normal'``.
+        layer (str | list[str], optional): the layer(s) to be initialized.
+            Defaults to None.
+    """
+
+    def __init__(self, gain=1, distribution='normal', **kwargs):
+        super().__init__(**kwargs)
+        self.gain = gain
+        self.distribution = distribution
+
+    def __call__(self, module):
+
+        def init(m):
+            if self.wholemodule:
+                xavier_init(m, self.gain, self.bias, self.distribution)
+            else:
+                layername = m.__class__.__name__
+                basesname = _get_bases_name(m)
+                if len(set(self.layer) & set([layername] + basesname)):
+                    xavier_init(m, self.gain, self.bias, self.distribution)
+
+        module.apply(init)
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+
+    def _get_init_info(self):
+        info = f'{self.__class__.__name__}: gain={self.gain}, ' \
+               f'distribution={self.distribution}, bias={self.bias}'
+        return info
+
+
+@INITIALIZERS.register_module(name='Normal')
+class NormalInit(BaseInit):
+    r"""Initialize module parameters with the values drawn from the normal
+    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
+
+    Args:
+        mean (int | float): the mean of the normal distribution. Defaults to 0.
+        std (int | float): the standard deviation of the normal distribution.
+            Defaults to 1.
+        bias (int | float): the value to fill the bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        layer (str | list[str], optional): the layer(s) to be initialized.
+            Defaults to None.
+
+    """
+
+    def __init__(self, mean=0, std=1, **kwargs):
+        super().__init__(**kwargs)
+        self.mean = mean
+        self.std = std
+
+    def __call__(self, module):
+
+        def init(m):
+            if self.wholemodule:
+                normal_init(m, self.mean, self.std, self.bias)
+            else:
+                layername = m.__class__.__name__
+                basesname = _get_bases_name(m)
+                if len(set(self.layer) & set([layername] + basesname)):
+                    normal_init(m, self.mean, self.std, self.bias)
+
+        module.apply(init)
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+
+    def _get_init_info(self):
+        info = f'{self.__class__.__name__}: mean={self.mean},' \
+               f' std={self.std}, bias={self.bias}'
+        return info
+
+
+@INITIALIZERS.register_module(name='TruncNormal')
+class TruncNormalInit(BaseInit):
+    r"""Initialize module parameters with the values drawn from the normal
+    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values
+    outside :math:`[a, b]` redrawn until they are within the bounds.
+
+    Args:
+        mean (float): the mean of the normal distribution. Defaults to 0.
+        std (float):  the standard deviation of the normal distribution.
+            Defaults to 1.
+        a (float): the minimum cutoff value. Defaults to -2.
+        b (float): the maximum cutoff value. Defaults to 2.
+        bias (float): the value to fill the bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        layer (str | list[str], optional): the layer(s) to be initialized.
+            Defaults to None.
+
+    """
+
+    def __init__(self,
+                 mean: float = 0,
+                 std: float = 1,
+                 a: float = -2,
+                 b: float = 2,
+                 **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.mean = mean
+        self.std = std
+        self.a = a
+        self.b = b
+
+    def __call__(self, module: nn.Module) -> None:
+
+        def init(m):
+            if self.wholemodule:
+                trunc_normal_init(m, self.mean, self.std, self.a, self.b,
+                                  self.bias)
+            else:
+                layername = m.__class__.__name__
+                basesname = _get_bases_name(m)
+                if len(set(self.layer) & set([layername] + basesname)):
+                    trunc_normal_init(m, self.mean, self.std, self.a, self.b,
+                                      self.bias)
+
+        module.apply(init)
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+
+    def _get_init_info(self):
+        info = f'{self.__class__.__name__}: a={self.a}, b={self.b},' \
+               f' mean={self.mean}, std={self.std}, bias={self.bias}'
+        return info
+
+
+@INITIALIZERS.register_module(name='Uniform')
+class UniformInit(BaseInit):
+    r"""Initialize module parameters with values drawn from the uniform
+    distribution :math:`\mathcal{U}(a, b)`.
+
+    Args:
+        a (int | float): the lower bound of the uniform distribution.
+            Defaults to 0.
+        b (int | float): the upper bound of the uniform distribution.
+            Defaults to 1.
+        bias (int | float): the value to fill the bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        layer (str | list[str], optional): the layer(s) to be initialized.
+            Defaults to None.
+    """
+
+    def __init__(self, a=0, b=1, **kwargs):
+        super().__init__(**kwargs)
+        self.a = a
+        self.b = b
+
+    def __call__(self, module):
+
+        def init(m):
+            if self.wholemodule:
+                uniform_init(m, self.a, self.b, self.bias)
+            else:
+                layername = m.__class__.__name__
+                basesname = _get_bases_name(m)
+                if len(set(self.layer) & set([layername] + basesname)):
+                    uniform_init(m, self.a, self.b, self.bias)
+
+        module.apply(init)
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+
+    def _get_init_info(self):
+        info = f'{self.__class__.__name__}: a={self.a},' \
+               f' b={self.b}, bias={self.bias}'
+        return info
+
+
+@INITIALIZERS.register_module(name='Kaiming')
+class KaimingInit(BaseInit):
+    r"""Initialize module parameters with the values according to the method
+    described in `Delving deep into rectifiers: Surpassing human-level
+    performance on ImageNet classification - He, K. et al. (2015).
+    <https://www.cv-foundation.org/openaccess/content_iccv_2015/
+    papers/He_Delving_Deep_into_ICCV_2015_paper.pdf>`_
+
+    Args:
+        a (int | float): the negative slope of the rectifier used after this
+            layer (only used with ``'leaky_relu'``). Defaults to 0.
+        mode (str):  either ``'fan_in'`` or ``'fan_out'``. Choosing
+            ``'fan_in'`` preserves the magnitude of the variance of the weights
+            in the forward pass. Choosing ``'fan_out'`` preserves the
+            magnitudes in the backwards pass. Defaults to ``'fan_out'``.
+        nonlinearity (str): the non-linear function (`nn.functional` name),
+            recommended to use only with ``'relu'`` or ``'leaky_relu'``.
+            Defaults to 'relu'.
+        bias (int | float): the value to fill the bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        distribution (str): distribution either be ``'normal'`` or
+            ``'uniform'``. Defaults to ``'normal'``.
+        layer (str | list[str], optional): the layer(s) to be initialized.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 a=0,
+                 mode='fan_out',
+                 nonlinearity='relu',
+                 distribution='normal',
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.a = a
+        self.mode = mode
+        self.nonlinearity = nonlinearity
+        self.distribution = distribution
+
+    def __call__(self, module):
+
+        def init(m):
+            if self.wholemodule:
+                kaiming_init(m, self.a, self.mode, self.nonlinearity,
+                             self.bias, self.distribution)
+            else:
+                layername = m.__class__.__name__
+                basesname = _get_bases_name(m)
+                if len(set(self.layer) & set([layername] + basesname)):
+                    kaiming_init(m, self.a, self.mode, self.nonlinearity,
+                                 self.bias, self.distribution)
+
+        module.apply(init)
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+
+    def _get_init_info(self):
+        info = f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, ' \
+               f'nonlinearity={self.nonlinearity}, ' \
+               f'distribution={self.distribution}, bias={self.bias}'
+        return info
+
+
+@INITIALIZERS.register_module(name='Caffe2Xavier')
+class Caffe2XavierInit(KaimingInit):
+    # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
+    # Acknowledgment to FAIR's internal code
+    def __init__(self, **kwargs):
+        super().__init__(
+            a=1,
+            mode='fan_in',
+            nonlinearity='leaky_relu',
+            distribution='uniform',
+            **kwargs)
+
+    def __call__(self, module):
+        super().__call__(module)
+
+
+@INITIALIZERS.register_module(name='Pretrained')
+class PretrainedInit(object):
+    """Initialize module by loading a pretrained model.
+
+    Args:
+        checkpoint (str): the checkpoint file of the pretrained model to be
+            loaded.
+        prefix (str, optional): the prefix of a sub-module in the pretrained
+            model. It is used to load only part of the pretrained model to
+            initialize the module. For example, if we would like to only
+            load the backbone of a detector model, we can set
+            ``prefix='backbone.'``. Defaults to None.
+        map_location (str): map tensors into proper locations.
+    """
+
+    def __init__(self, checkpoint, prefix=None, map_location=None):
+        self.checkpoint = checkpoint
+        self.prefix = prefix
+        self.map_location = map_location
+
+    def __call__(self, module):
+        from annotator.uniformer.mmcv.runner import (_load_checkpoint_with_prefix, load_checkpoint,
+                                 load_state_dict)
+        logger = get_logger('mmcv')
+        if self.prefix is None:
+            print_log(f'load model from: {self.checkpoint}', logger=logger)
+            load_checkpoint(
+                module,
+                self.checkpoint,
+                map_location=self.map_location,
+                strict=False,
+                logger=logger)
+        else:
+            print_log(
+                f'load {self.prefix} in model from: {self.checkpoint}',
+                logger=logger)
+            state_dict = _load_checkpoint_with_prefix(
+                self.prefix, self.checkpoint, map_location=self.map_location)
+            load_state_dict(module, state_dict, strict=False, logger=logger)
+
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+
+    def _get_init_info(self):
+        info = f'{self.__class__.__name__}: load from {self.checkpoint}'
+        return info
+
+
+def _initialize(module, cfg, wholemodule=False):
+    func = build_from_cfg(cfg, INITIALIZERS)
+    # The wholemodule flag is for the override mode: there is no layer key in
+    # an override, so the initializer assigns init values to the whole module
+    # named in the override.
+    func.wholemodule = wholemodule
+    func(module)
+
+
+def _initialize_override(module, override, cfg):
+    if not isinstance(override, (dict, list)):
+        raise TypeError('override must be a dict or a list of dict, '
+                        f'but got {type(override)}')
+
+    override = [override] if isinstance(override, dict) else override
+
+    for override_ in override:
+
+        cp_override = copy.deepcopy(override_)
+        name = cp_override.pop('name', None)
+        if name is None:
+            raise ValueError('`override` must contain the key "name", '
+                             f'but got {cp_override}')
+        # if override only has the name key, use the args from init_cfg
+        if not cp_override:
+            cp_override.update(cfg)
+        # if override has the name key and other args but no type key,
+        # raise an error
+        elif 'type' not in cp_override.keys():
+            raise ValueError(
+                f'`override` needs a "type" key, but got {cp_override}')
+
+        if hasattr(module, name):
+            _initialize(getattr(module, name), cp_override, wholemodule=True)
+        else:
+            raise RuntimeError(f'module did not have attribute {name}, '
+                               f'but init_cfg is {cp_override}.')
+
+
+def initialize(module, init_cfg):
+    """Initialize a module.
+
+    Args:
+        module (``torch.nn.Module``): the module will be initialized.
+        init_cfg (dict | list[dict]): initialization configuration dict to
+            define initializer. OpenMMLab has implemented several initializers,
+            including ``Constant``, ``Xavier``, ``Normal``, ``TruncNormal``,
+            ``Uniform``, ``Kaiming``, ``Caffe2Xavier`` and ``Pretrained``.
+    Example:
+        >>> module = nn.Linear(2, 3, bias=True)
+        >>> init_cfg = dict(type='Constant', layer='Linear', val=1, bias=2)
+        >>> initialize(module, init_cfg)
+
+        >>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1,2))
+        >>> # define key ``'layer'`` for initializing layer with different
+        >>> # configuration
+        >>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1),
+                dict(type='Constant', layer='Linear', val=2)]
+        >>> initialize(module, init_cfg)
+
+        >>> # define key``'override'`` to initialize some specific part in
+        >>> # module
+        >>> class FooNet(nn.Module):
+        >>>     def __init__(self):
+        >>>         super().__init__()
+        >>>         self.feat = nn.Conv2d(3, 16, 3)
+        >>>         self.reg = nn.Conv2d(16, 10, 3)
+        >>>         self.cls = nn.Conv2d(16, 5, 3)
+        >>> model = FooNet()
+        >>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d',
+        >>>     override=dict(type='Constant', name='reg', val=3, bias=4))
+        >>> initialize(model, init_cfg)
+
+        >>> model = ResNet(depth=50)
+        >>> # Initialize weights with the pretrained model.
+        >>> init_cfg = dict(type='Pretrained',
+                checkpoint='torchvision://resnet50')
+        >>> initialize(model, init_cfg)
+
+        >>> # Initialize weights of a sub-module with the specific part of
+        >>> # a pretrained model by using "prefix".
+        >>> url = 'http://download.openmmlab.com/mmdetection/v2.0/retinanet/'\
+        >>>     'retinanet_r50_fpn_1x_coco/'\
+        >>>     'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth'
+        >>> init_cfg = dict(type='Pretrained',
+                checkpoint=url, prefix='backbone.')
+    """
+    if not isinstance(init_cfg, (dict, list)):
+        raise TypeError('init_cfg must be a dict or a list of dict, '
+                        f'but got {type(init_cfg)}')
+
+    if isinstance(init_cfg, dict):
+        init_cfg = [init_cfg]
+
+    for cfg in init_cfg:
+        # should deeply copy the original config because cfg may be used by
+        # other modules, e.g., one init_cfg shared by multiple bottleneck
+        # blocks; the cfg would be changed after the pop and that would change
+        # the initialization behavior of the other modules
+        cp_cfg = copy.deepcopy(cfg)
+        override = cp_cfg.pop('override', None)
+        _initialize(module, cp_cfg)
+
+        if override is not None:
+            cp_cfg.pop('layer', None)
+            _initialize_override(module, override, cp_cfg)
+        else:
+            # All attributes in the module have the same initialization.
+            pass
+
+
+def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
+                           b: float) -> Tensor:
+    # Method based on
+    # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+    # Modified from
+    # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
+    def norm_cdf(x):
+        # Computes standard normal cumulative distribution function
+        return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+    if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn(
+            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+            'The distribution of values may be incorrect.',
+            stacklevel=2)
+
+    with torch.no_grad():
+        # Values are generated by using a truncated uniform distribution and
+        # then using the inverse CDF for the normal distribution.
+        # Get upper and lower cdf values
+        lower = norm_cdf((a - mean) / std)
+        upper = norm_cdf((b - mean) / std)
+
+        # Uniformly fill tensor with values from [lower, upper], then translate
+        # to [2lower-1, 2upper-1].
+        tensor.uniform_(2 * lower - 1, 2 * upper - 1)
+
+        # Use inverse cdf transform for normal distribution to get truncated
+        # standard normal
+        tensor.erfinv_()
+
+        # Transform to proper mean, std
+        tensor.mul_(std * math.sqrt(2.))
+        tensor.add_(mean)
+
+        # Clamp to ensure it's in the proper range
+        tensor.clamp_(min=a, max=b)
+        return tensor
+
+
+def trunc_normal_(tensor: Tensor,
+                  mean: float = 0.,
+                  std: float = 1.,
+                  a: float = -2.,
+                  b: float = 2.) -> Tensor:
+    r"""Fills the input Tensor with values drawn from a truncated
+    normal distribution. The values are effectively drawn from the
+    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+    with values outside :math:`[a, b]` redrawn until they are within
+    the bounds. The method used for generating the random values works
+    best when :math:`a \leq \text{mean} \leq b`.
+
+    Modified from
+    https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
+
+    Args:
+        tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
+        mean (float): the mean of the normal distribution.
+        std (float): the standard deviation of the normal distribution.
+        a (float): the minimum cutoff value.
+        b (float): the maximum cutoff value.
+    """
+    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
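+
+
+if __name__ == '__main__':
+    # A minimal sketch (illustrative addition, not upstream mmcv code): use
+    # the registry-based `initialize()` API on a toy module and fill a tensor
+    # with `trunc_normal_`. The layer choices below are arbitrary.
+    toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Linear(8, 2))
+    initialize(toy, [
+        dict(type='Kaiming', layer='Conv2d'),
+        dict(type='Normal', layer='Linear', std=0.01),
+    ])
+    t = torch.empty(4, 4)
+    trunc_normal_(t, mean=0., std=0.02, a=-0.04, b=0.04)
+    assert t.abs().max() <= 0.04  # all samples lie inside [a, b]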
diff --git a/annotator/uniformer/mmcv/cnn/vgg.py b/annotator/uniformer/mmcv/cnn/vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..8778b649561a45a9652b1a15a26c2d171e58f3e1
--- /dev/null
+++ b/annotator/uniformer/mmcv/cnn/vgg.py
@@ -0,0 +1,175 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+
+import torch.nn as nn
+
+from .utils import constant_init, kaiming_init, normal_init
+
+
+def conv3x3(in_planes, out_planes, dilation=1):
+    """3x3 convolution with padding."""
+    return nn.Conv2d(
+        in_planes,
+        out_planes,
+        kernel_size=3,
+        padding=dilation,
+        dilation=dilation)
+
+
+def make_vgg_layer(inplanes,
+                   planes,
+                   num_blocks,
+                   dilation=1,
+                   with_bn=False,
+                   ceil_mode=False):
+    layers = []
+    for _ in range(num_blocks):
+        layers.append(conv3x3(inplanes, planes, dilation))
+        if with_bn:
+            layers.append(nn.BatchNorm2d(planes))
+        layers.append(nn.ReLU(inplace=True))
+        inplanes = planes
+    layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode))
+
+    return layers
+
+
+class VGG(nn.Module):
+    """VGG backbone.
+
+    Args:
+        depth (int): Depth of vgg, from {11, 13, 16, 19}.
+        with_bn (bool): Use BatchNorm or not.
+        num_classes (int): number of classes for classification.
+        num_stages (int): VGG stages, normally 5.
+        dilations (Sequence[int]): Dilation of each stage.
+        out_indices (Sequence[int]): Output from which stages.
+        frozen_stages (int): Stages to be frozen (all parameters fixed).
+            -1 means not freezing any parameters.
+        bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
+            running stats (mean and var).
+        bn_frozen (bool): Whether to freeze weight and bias of BN layers.
+    """
+
+    arch_settings = {
+        11: (1, 1, 2, 2, 2),
+        13: (2, 2, 2, 2, 2),
+        16: (2, 2, 3, 3, 3),
+        19: (2, 2, 4, 4, 4)
+    }
+
+    def __init__(self,
+                 depth,
+                 with_bn=False,
+                 num_classes=-1,
+                 num_stages=5,
+                 dilations=(1, 1, 1, 1, 1),
+                 out_indices=(0, 1, 2, 3, 4),
+                 frozen_stages=-1,
+                 bn_eval=True,
+                 bn_frozen=False,
+                 ceil_mode=False,
+                 with_last_pool=True):
+        super(VGG, self).__init__()
+        if depth not in self.arch_settings:
+            raise KeyError(f'invalid depth {depth} for vgg')
+        assert num_stages >= 1 and num_stages <= 5
+        stage_blocks = self.arch_settings[depth]
+        self.stage_blocks = stage_blocks[:num_stages]
+        assert len(dilations) == num_stages
+        assert max(out_indices) <= num_stages
+
+        self.num_classes = num_classes
+        self.out_indices = out_indices
+        self.frozen_stages = frozen_stages
+        self.bn_eval = bn_eval
+        self.bn_frozen = bn_frozen
+
+        self.inplanes = 3
+        start_idx = 0
+        vgg_layers = []
+        self.range_sub_modules = []
+        for i, num_blocks in enumerate(self.stage_blocks):
+            num_modules = num_blocks * (2 + with_bn) + 1
+            end_idx = start_idx + num_modules
+            dilation = dilations[i]
+            planes = 64 * 2**i if i < 4 else 512
+            vgg_layer = make_vgg_layer(
+                self.inplanes,
+                planes,
+                num_blocks,
+                dilation=dilation,
+                with_bn=with_bn,
+                ceil_mode=ceil_mode)
+            vgg_layers.extend(vgg_layer)
+            self.inplanes = planes
+            self.range_sub_modules.append([start_idx, end_idx])
+            start_idx = end_idx
+        if not with_last_pool:
+            vgg_layers.pop(-1)
+            self.range_sub_modules[-1][1] -= 1
+        self.module_name = 'features'
+        self.add_module(self.module_name, nn.Sequential(*vgg_layers))
+
+        if self.num_classes > 0:
+            self.classifier = nn.Sequential(
+                nn.Linear(512 * 7 * 7, 4096),
+                nn.ReLU(True),
+                nn.Dropout(),
+                nn.Linear(4096, 4096),
+                nn.ReLU(True),
+                nn.Dropout(),
+                nn.Linear(4096, num_classes),
+            )
+
+    def init_weights(self, pretrained=None):
+        if isinstance(pretrained, str):
+            logger = logging.getLogger()
+            from ..runner import load_checkpoint
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    kaiming_init(m)
+                elif isinstance(m, nn.BatchNorm2d):
+                    constant_init(m, 1)
+                elif isinstance(m, nn.Linear):
+                    normal_init(m, std=0.01)
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    def forward(self, x):
+        outs = []
+        vgg_layers = getattr(self, self.module_name)
+        for i in range(len(self.stage_blocks)):
+            for j in range(*self.range_sub_modules[i]):
+                vgg_layer = vgg_layers[j]
+                x = vgg_layer(x)
+            if i in self.out_indices:
+                outs.append(x)
+        if self.num_classes > 0:
+            x = x.view(x.size(0), -1)
+            x = self.classifier(x)
+            outs.append(x)
+        if len(outs) == 1:
+            return outs[0]
+        else:
+            return tuple(outs)
+
+    def train(self, mode=True):
+        super(VGG, self).train(mode)
+        if self.bn_eval:
+            for m in self.modules():
+                if isinstance(m, nn.BatchNorm2d):
+                    m.eval()
+                    if self.bn_frozen:
+                        for params in m.parameters():
+                            params.requires_grad = False
+        vgg_layers = getattr(self, self.module_name)
+        if mode and self.frozen_stages >= 0:
+            for i in range(self.frozen_stages):
+                for j in range(*self.range_sub_modules[i]):
+                    mod = vgg_layers[j]
+                    mod.eval()
+                    for param in mod.parameters():
+                        param.requires_grad = False
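+
+
+if __name__ == '__main__':
+    # A minimal sketch (illustrative addition, not upstream mmcv code): build
+    # a VGG-16 backbone, run a dummy forward pass and inspect the multi-scale
+    # feature maps returned for `out_indices`.
+    import torch
+    model = VGG(16, out_indices=(2, 3, 4))
+    model.init_weights()
+    model.eval()
+    with torch.no_grad():
+        feats = model(torch.randn(1, 3, 224, 224))
+    for feat in feats:
+        print(tuple(feat.shape))  # (1, 256, 28, 28), (1, 512, 14, 14), ...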
diff --git a/annotator/uniformer/mmcv/engine/__init__.py b/annotator/uniformer/mmcv/engine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3193b7f664e19ce2458d81c836597fa22e4bb082
--- /dev/null
+++ b/annotator/uniformer/mmcv/engine/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .test import (collect_results_cpu, collect_results_gpu, multi_gpu_test,
+                   single_gpu_test)
+
+__all__ = [
+    'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test',
+    'single_gpu_test'
+]
diff --git a/annotator/uniformer/mmcv/engine/test.py b/annotator/uniformer/mmcv/engine/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dbeef271db634ec2dadfda3bc0b5ef9c7a677ff
--- /dev/null
+++ b/annotator/uniformer/mmcv/engine/test.py
@@ -0,0 +1,202 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import pickle
+import shutil
+import tempfile
+import time
+
+import torch
+import torch.distributed as dist
+
+import annotator.uniformer.mmcv as mmcv
+from annotator.uniformer.mmcv.runner import get_dist_info
+
+
+def single_gpu_test(model, data_loader):
+    """Test model with a single gpu.
+
+    This method tests the model with a single gpu and displays a test
+    progress bar.
+
+    Args:
+        model (nn.Module): Model to be tested.
+        data_loader (DataLoader): PyTorch data loader.
+
+    Returns:
+        list: The prediction results.
+    """
+    model.eval()
+    results = []
+    dataset = data_loader.dataset
+    prog_bar = mmcv.ProgressBar(len(dataset))
+    for data in data_loader:
+        with torch.no_grad():
+            result = model(return_loss=False, **data)
+        results.extend(result)
+
+        # Assume result has the same length as batch_size
+        # refer to https://github.com/open-mmlab/mmcv/issues/985
+        batch_size = len(result)
+        for _ in range(batch_size):
+            prog_bar.update()
+    return results
+
+
+def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
+    """Test model with multiple gpus.
+
+    This method tests the model with multiple gpus and collects the results
+    under two different modes: gpu and cpu. By setting ``gpu_collect=True``,
+    it encodes results to gpu tensors and uses gpu communication to collect
+    them. In cpu mode it saves the results on different gpus to ``tmpdir``
+    and the rank 0 worker collects them.
+
+    Args:
+        model (nn.Module): Model to be tested.
+        data_loader (DataLoader): PyTorch data loader.
+        tmpdir (str): Path of directory to save the temporary results from
+            different gpus under cpu mode.
+        gpu_collect (bool): Option to use either gpu or cpu to collect results.
+
+    Returns:
+        list: The prediction results.
+    """
+    model.eval()
+    results = []
+    dataset = data_loader.dataset
+    rank, world_size = get_dist_info()
+    if rank == 0:
+        prog_bar = mmcv.ProgressBar(len(dataset))
+    time.sleep(2)  # This line can prevent deadlock problem in some cases.
+    for i, data in enumerate(data_loader):
+        with torch.no_grad():
+            result = model(return_loss=False, **data)
+        results.extend(result)
+
+        if rank == 0:
+            batch_size = len(result)
+            batch_size_all = batch_size * world_size
+            if batch_size_all + prog_bar.completed > len(dataset):
+                batch_size_all = len(dataset) - prog_bar.completed
+            for _ in range(batch_size_all):
+                prog_bar.update()
+
+    # collect results from all ranks
+    if gpu_collect:
+        results = collect_results_gpu(results, len(dataset))
+    else:
+        results = collect_results_cpu(results, len(dataset), tmpdir)
+    return results
+
+
+def collect_results_cpu(result_part, size, tmpdir=None):
+    """Collect results under cpu mode.
+
+    In cpu mode, this function saves the results from different gpus to
+    ``tmpdir`` and the rank 0 worker collects them.
+
+    Args:
+        result_part (list): Result list containing result parts
+            to be collected.
+        size (int): Size of the results, commonly equal to length of
+            the results.
+        tmpdir (str | None): temporary directory in which the collected
+            results are stored. If set to None, a random temporary directory
+            will be created.
+
+    Returns:
+        list: The collected results.
+    """
+    rank, world_size = get_dist_info()
+    # create a tmp dir if it is not specified
+    if tmpdir is None:
+        MAX_LEN = 512
+        # 32 is whitespace
+        dir_tensor = torch.full((MAX_LEN, ),
+                                32,
+                                dtype=torch.uint8,
+                                device='cuda')
+        if rank == 0:
+            mmcv.mkdir_or_exist('.dist_test')
+            tmpdir = tempfile.mkdtemp(dir='.dist_test')
+            tmpdir = torch.tensor(
+                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
+            dir_tensor[:len(tmpdir)] = tmpdir
+        dist.broadcast(dir_tensor, 0)
+        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+    else:
+        mmcv.mkdir_or_exist(tmpdir)
+    # dump the part result to the dir
+    mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
+    dist.barrier()
+    # collect all parts
+    if rank != 0:
+        return None
+    else:
+        # load results of all parts from tmp dir
+        part_list = []
+        for i in range(world_size):
+            part_file = osp.join(tmpdir, f'part_{i}.pkl')
+            part_result = mmcv.load(part_file)
+            # When data is severely insufficient, an empty part_result
+            # on a certain gpu could make the overall outputs empty.
+            if part_result:
+                part_list.append(part_result)
+        # sort the results
+        ordered_results = []
+        for res in zip(*part_list):
+            ordered_results.extend(list(res))
+        # the dataloader may pad some samples
+        ordered_results = ordered_results[:size]
+        # remove tmp dir
+        shutil.rmtree(tmpdir)
+        return ordered_results
+
+
+def collect_results_gpu(result_part, size):
+    """Collect results under gpu mode.
+
+    In gpu mode, this function encodes the results to gpu tensors and uses
+    gpu communication to collect them.
+
+    Args:
+        result_part (list): Result list containing result parts
+            to be collected.
+        size (int): Size of the results, commonly equal to length of
+            the results.
+
+    Returns:
+        list: The collected results.
+    """
+    rank, world_size = get_dist_info()
+    # dump result part to tensor with pickle
+    part_tensor = torch.tensor(
+        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
+    # gather all result part tensor shape
+    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+    shape_list = [shape_tensor.clone() for _ in range(world_size)]
+    dist.all_gather(shape_list, shape_tensor)
+    # padding result part tensor to max length
+    shape_max = torch.tensor(shape_list).max()
+    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+    part_send[:shape_tensor[0]] = part_tensor
+    part_recv_list = [
+        part_tensor.new_zeros(shape_max) for _ in range(world_size)
+    ]
+    # gather all result part
+    dist.all_gather(part_recv_list, part_send)
+
+    if rank == 0:
+        part_list = []
+        for recv, shape in zip(part_recv_list, shape_list):
+            part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())
+            # When data is severely insufficient, an empty part_result
+            # on a certain gpu could make the overall outputs empty.
+            if part_result:
+                part_list.append(part_result)
+        # sort the results
+        ordered_results = []
+        for res in zip(*part_list):
+            ordered_results.extend(list(res))
+        # the dataloader may pad some samples
+        ordered_results = ordered_results[:size]
+        return ordered_results
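+
+
+if __name__ == '__main__':
+    # A minimal sketch (illustrative addition, not upstream mmcv code):
+    # `single_gpu_test` expects an mmcv-style model whose forward accepts
+    # `return_loss=False` and a DataLoader that yields keyword-argument
+    # dicts. `DummyModel` below is a stand-in, not a real mmcv model.
+    from torch.utils.data import DataLoader, TensorDataset
+
+    class DummyModel(torch.nn.Module):
+
+        def forward(self, return_loss=True, img=None):
+            # one "prediction" per sample in the batch
+            return [float(x.sum()) for x in img]
+
+    dataset = TensorDataset(torch.randn(8, 3, 4, 4))
+    loader = DataLoader(
+        dataset,
+        batch_size=2,
+        collate_fn=lambda batch: {'img': torch.stack([b[0] for b in batch])})
+    results = single_gpu_test(DummyModel(), loader)
+    print(len(results))  # 8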
diff --git a/annotator/uniformer/mmcv/fileio/__init__.py b/annotator/uniformer/mmcv/fileio/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2051b85f7e59bff7bdbaa131849ce8cd31f059a4
--- /dev/null
+++ b/annotator/uniformer/mmcv/fileio/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .file_client import BaseStorageBackend, FileClient
+from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
+from .io import dump, load, register_handler
+from .parse import dict_from_file, list_from_file
+
+__all__ = [
+    'BaseStorageBackend', 'FileClient', 'load', 'dump', 'register_handler',
+    'BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler',
+    'list_from_file', 'dict_from_file'
+]
diff --git a/annotator/uniformer/mmcv/fileio/file_client.py b/annotator/uniformer/mmcv/fileio/file_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..950f0c1aeab14b8e308a7455ccd64a95b5d98add
--- /dev/null
+++ b/annotator/uniformer/mmcv/fileio/file_client.py
@@ -0,0 +1,1148 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+import os
+import os.path as osp
+import re
+import tempfile
+import warnings
+from abc import ABCMeta, abstractmethod
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Iterable, Iterator, Optional, Tuple, Union
+from urllib.request import urlopen
+
+import annotator.uniformer.mmcv as mmcv
+from annotator.uniformer.mmcv.utils.misc import has_method
+from annotator.uniformer.mmcv.utils.path import is_filepath
+
+
+class BaseStorageBackend(metaclass=ABCMeta):
+    """Abstract class of storage backends.
+
+    All backends need to implement two APIs: ``get()`` and ``get_text()``.
+    ``get()`` reads the file as a byte stream and ``get_text()`` reads the
+    file as text.
+    """
+
+    # a flag to indicate whether the backend can create a symlink for a file
+    _allow_symlink = False
+
+    @property
+    def name(self):
+        return self.__class__.__name__
+
+    @property
+    def allow_symlink(self):
+        return self._allow_symlink
+
+    @abstractmethod
+    def get(self, filepath):
+        pass
+
+    @abstractmethod
+    def get_text(self, filepath):
+        pass
+
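+
+# A minimal sketch (illustrative addition, not upstream mmcv code): a custom
+# backend only has to implement `get()` and `get_text()`. `LocalDemoBackend`
+# is a hypothetical name used purely for demonstration.
+class LocalDemoBackend(BaseStorageBackend):
+
+    def get(self, filepath):
+        with open(filepath, 'rb') as f:
+            return f.read()
+
+    def get_text(self, filepath, encoding='utf-8'):
+        with open(filepath, 'r', encoding=encoding) as f:
+            return f.read()
+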
+
+class CephBackend(BaseStorageBackend):
+    """Ceph storage backend (for internal use).
+
+    Args:
+        path_mapping (dict|None): path mapping dict from local path to Petrel
+            path. When ``path_mapping={'src': 'dst'}``, ``src`` in ``filepath``
+            will be replaced by ``dst``. Default: None.
+
+    .. warning::
+        :class:`mmcv.fileio.file_client.CephBackend` will be deprecated,
+        please use :class:`mmcv.fileio.file_client.PetrelBackend` instead.
+    """
+
+    def __init__(self, path_mapping=None):
+        try:
+            import ceph
+        except ImportError:
+            raise ImportError('Please install ceph to enable CephBackend.')
+
+        warnings.warn(
+            'CephBackend will be deprecated, please use PetrelBackend instead')
+        self._client = ceph.S3Client()
+        assert isinstance(path_mapping, dict) or path_mapping is None
+        self.path_mapping = path_mapping
+
+    def get(self, filepath):
+        filepath = str(filepath)
+        if self.path_mapping is not None:
+            for k, v in self.path_mapping.items():
+                filepath = filepath.replace(k, v)
+        value = self._client.Get(filepath)
+        value_buf = memoryview(value)
+        return value_buf
+
+    def get_text(self, filepath, encoding=None):
+        raise NotImplementedError
+
+
+class PetrelBackend(BaseStorageBackend):
+    """Petrel storage backend (for internal use).
+
+    PetrelBackend supports reading and writing data to multiple clusters.
+    If the file path contains the cluster name, PetrelBackend will read data
+    from specified cluster or write data to it. Otherwise, PetrelBackend will
+    access the default cluster.
+
+    Args:
+        path_mapping (dict, optional): Path mapping dict from local path to
+            Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in
+            ``filepath`` will be replaced by ``dst``. Default: None.
+        enable_mc (bool, optional): Whether to enable memcached support.
+            Default: True.
+
+    Examples:
+        >>> filepath1 = 's3://path/of/file'
+        >>> filepath2 = 'cluster-name:s3://path/of/file'
+        >>> client = PetrelBackend()
+        >>> client.get(filepath1)  # get data from default cluster
+        >>> client.get(filepath2)  # get data from 'cluster-name' cluster
+    """
+
+    def __init__(self,
+                 path_mapping: Optional[dict] = None,
+                 enable_mc: bool = True):
+        try:
+            from petrel_client import client
+        except ImportError:
+            raise ImportError('Please install petrel_client to enable '
+                              'PetrelBackend.')
+
+        self._client = client.Client(enable_mc=enable_mc)
+        assert isinstance(path_mapping, dict) or path_mapping is None
+        self.path_mapping = path_mapping
+
+    def _map_path(self, filepath: Union[str, Path]) -> str:
+        """Map ``filepath`` to a string path whose prefix will be replaced by
+        :attr:`self.path_mapping`.
+
+        Args:
+            filepath (str): Path to be mapped.
+        """
+        filepath = str(filepath)
+        if self.path_mapping is not None:
+            for k, v in self.path_mapping.items():
+                filepath = filepath.replace(k, v)
+        return filepath
+
+    def _format_path(self, filepath: str) -> str:
+        """Convert a ``filepath`` to standard format of petrel oss.
+
+        If the ``filepath`` is concatenated by ``os.path.join`` in a Windows
+        environment, the ``filepath`` will be in the format of
+        's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the
+        above ``filepath`` will be converted to 's3://bucket_name/image.jpg'.
+
+        Args:
+            filepath (str): Path to be formatted.
+        """
+        return re.sub(r'\\+', '/', filepath)
+
+    def get(self, filepath: Union[str, Path]) -> memoryview:
+        """Read data from a given ``filepath`` with 'rb' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+
+        Returns:
+            memoryview: A memory view of the expected bytes object to avoid
+                copying. The memoryview object can be converted to bytes by
+                ``value_buf.tobytes()``.
+        """
+        filepath = self._map_path(filepath)
+        filepath = self._format_path(filepath)
+        value = self._client.Get(filepath)
+        value_buf = memoryview(value)
+        return value_buf
+
+    def get_text(self,
+                 filepath: Union[str, Path],
+                 encoding: str = 'utf-8') -> str:
+        """Read data from a given ``filepath`` with 'r' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+
+        Returns:
+            str: Expected text reading from ``filepath``.
+        """
+        return str(self.get(filepath), encoding=encoding)
+
+    def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
+        """Save data to a given ``filepath``.
+
+        Args:
+            obj (bytes): Data to be saved.
+            filepath (str or Path): Path to write data.
+        """
+        filepath = self._map_path(filepath)
+        filepath = self._format_path(filepath)
+        self._client.put(filepath, obj)
+
+    def put_text(self,
+                 obj: str,
+                 filepath: Union[str, Path],
+                 encoding: str = 'utf-8') -> None:
+        """Save data to a given ``filepath``.
+
+        Args:
+            obj (str): Data to be written.
+            filepath (str or Path): Path to write data.
+            encoding (str): The encoding format used to encode the ``obj``.
+                Default: 'utf-8'.
+        """
+        self.put(bytes(obj, encoding=encoding), filepath)
+
+    def remove(self, filepath: Union[str, Path]) -> None:
+        """Remove a file.
+
+        Args:
+            filepath (str or Path): Path to be removed.
+        """
+        if not has_method(self._client, 'delete'):
+            raise NotImplementedError(
+                ('The current version of the Petrel Python SDK does not '
+                 'support the `delete` method; please use a higher version '
+                 'or the dev branch instead.'))
+
+        filepath = self._map_path(filepath)
+        filepath = self._format_path(filepath)
+        self._client.delete(filepath)
+
+    def exists(self, filepath: Union[str, Path]) -> bool:
+        """Check whether a file path exists.
+
+        Args:
+            filepath (str or Path): Path to be checked whether exists.
+
+        Returns:
+            bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
+        """
+        if not (has_method(self._client, 'contains')
+                and has_method(self._client, 'isdir')):
+            raise NotImplementedError(
+                ('The current version of the Petrel Python SDK does not '
+                 'support the `contains` and `isdir` methods; please use a '
+                 'higher version or the dev branch instead.'))
+
+        filepath = self._map_path(filepath)
+        filepath = self._format_path(filepath)
+        return self._client.contains(filepath) or self._client.isdir(filepath)
+
+    def isdir(self, filepath: Union[str, Path]) -> bool:
+        """Check whether a file path is a directory.
+
+        Args:
+            filepath (str or Path): Path to be checked whether it is a
+                directory.
+
+        Returns:
+            bool: Return ``True`` if ``filepath`` points to a directory,
+                ``False`` otherwise.
+        """
+        if not has_method(self._client, 'isdir'):
+            raise NotImplementedError(
+                ('The current version of the Petrel Python SDK does not '
+                 'support the `isdir` method; please use a higher version '
+                 'or the dev branch instead.'))
+
+        filepath = self._map_path(filepath)
+        filepath = self._format_path(filepath)
+        return self._client.isdir(filepath)
+
+    def isfile(self, filepath: Union[str, Path]) -> bool:
+        """Check whether a file path is a file.
+
+        Args:
+            filepath (str or Path): Path to be checked whether it is a file.
+
+        Returns:
+            bool: Return ``True`` if ``filepath`` points to a file, ``False``
+                otherwise.
+        """
+        if not has_method(self._client, 'contains'):
+            raise NotImplementedError(
+                ('The current version of the Petrel Python SDK does not '
+                 'support the `contains` method; please use a higher '
+                 'version or the dev branch instead.'))
+
+        filepath = self._map_path(filepath)
+        filepath = self._format_path(filepath)
+        return self._client.contains(filepath)
+
+    def join_path(self, filepath: Union[str, Path],
+                  *filepaths: Union[str, Path]) -> str:
+        """Concatenate all file paths.
+
+        Args:
+            filepath (str or Path): Path to be concatenated.
+
+        Returns:
+            str: The result after concatenation.
+        """
+        filepath = self._format_path(self._map_path(filepath))
+        if filepath.endswith('/'):
+            filepath = filepath[:-1]
+        formatted_paths = [filepath]
+        for path in filepaths:
+            formatted_paths.append(self._format_path(self._map_path(path)))
+        return '/'.join(formatted_paths)
+
+    @contextmanager
+    def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
+        """Download a file from ``filepath`` and return a temporary path.
+
+        ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`.
+        It can be called within a ``with`` statement, and when exiting the
+        ``with`` statement, the temporary path will be released.
+
+        Args:
+            filepath (str | Path): Download a file from ``filepath``.
+
+        Examples:
+            >>> client = PetrelBackend()
+            >>> # After exiting from the ``with`` clause,
+            >>> # the path will be removed
+            >>> with client.get_local_path('s3://path/of/your/file') as path:
+            ...     # do something here
+
+        Yields:
+            Iterable[str]: Only yield one temporary path.
+        """
+        filepath = self._map_path(filepath)
+        filepath = self._format_path(filepath)
+        assert self.isfile(filepath)
+        try:
+            f = tempfile.NamedTemporaryFile(delete=False)
+            f.write(self.get(filepath))
+            f.close()
+            yield f.name
+        finally:
+            os.remove(f.name)
+
+    def list_dir_or_file(self,
+                         dir_path: Union[str, Path],
+                         list_dir: bool = True,
+                         list_file: bool = True,
+                         suffix: Optional[Union[str, Tuple[str]]] = None,
+                         recursive: bool = False) -> Iterator[str]:
+        """Scan a directory to find the directories or files of interest, in
+        arbitrary order.
+
+        Note:
+            Petrel has no concept of directories but it simulates the directory
+            hierarchy in the filesystem through public prefixes. In addition,
+            if the returned path ends with '/', it means the path is a public
+            prefix which is a logical directory.
+
+        Note:
+            :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
+            In addition, the returned directory path will not contain the
+            trailing '/', which is consistent with other backends.
+
+        Args:
+            dir_path (str | Path): Path of the directory.
+            list_dir (bool): List the directories. Default: True.
+            list_file (bool): List the path of files. Default: True.
+            suffix (str or tuple[str], optional):  File suffix
+                that we are interested in. Default: None.
+            recursive (bool): If set to True, recursively scan the
+                directory. Default: False.
+
+        Yields:
+            Iterable[str]: A relative path to ``dir_path``.
+        """
+        if not has_method(self._client, 'list'):
+            raise NotImplementedError(
+                ('The current version of the Petrel Python SDK does not '
+                 'support the `list` method, please use a newer version '
+                 'or the dev branch instead.'))
+
+        dir_path = self._map_path(dir_path)
+        dir_path = self._format_path(dir_path)
+        if list_dir and suffix is not None:
+            raise TypeError(
+                '`list_dir` should be False when `suffix` is not None')
+
+        if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+            raise TypeError('`suffix` must be a string or tuple of strings')
+
+        # Petrel's simulated directory hierarchy assumes that directory paths
+        # should end with `/`
+        if not dir_path.endswith('/'):
+            dir_path += '/'
+
+        root = dir_path
+
+        def _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+                              recursive):
+            for path in self._client.list(dir_path):
+                # the `self.isdir` is not used here to determine whether path
+                # is a directory, because `self.isdir` relies on
+                # `self._client.list`
+                if path.endswith('/'):  # a directory path
+                    next_dir_path = self.join_path(dir_path, path)
+                    if list_dir:
+                        # get the relative path and exclude the last
+                        # character '/'
+                        rel_dir = next_dir_path[len(root):-1]
+                        yield rel_dir
+                    if recursive:
+                        yield from _list_dir_or_file(next_dir_path, list_dir,
+                                                     list_file, suffix,
+                                                     recursive)
+                else:  # a file path
+                    absolute_path = self.join_path(dir_path, path)
+                    rel_path = absolute_path[len(root):]
+                    if (suffix is None
+                            or rel_path.endswith(suffix)) and list_file:
+                        yield rel_path
+
+        return _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+                                 recursive)
+
+
+class MemcachedBackend(BaseStorageBackend):
+    """Memcached storage backend.
+
+    Attributes:
+        server_list_cfg (str): Config file for memcached server list.
+        client_cfg (str): Config file for memcached client.
+        sys_path (str | None): Additional path to be appended to `sys.path`.
+            Default: None.
+    """
+
+    def __init__(self, server_list_cfg, client_cfg, sys_path=None):
+        if sys_path is not None:
+            import sys
+            sys.path.append(sys_path)
+        try:
+            import mc
+        except ImportError:
+            raise ImportError(
+                'Please install memcached to enable MemcachedBackend.')
+
+        self.server_list_cfg = server_list_cfg
+        self.client_cfg = client_cfg
+        self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg,
+                                                      self.client_cfg)
+        # mc.pyvector serves as a pointer to a memory cache buffer
+        self._mc_buffer = mc.pyvector()
+
+    def get(self, filepath):
+        filepath = str(filepath)
+        import mc
+        self._client.Get(filepath, self._mc_buffer)
+        value_buf = mc.ConvertBuffer(self._mc_buffer)
+        return value_buf
+
+    def get_text(self, filepath, encoding=None):
+        raise NotImplementedError
+
+
+class LmdbBackend(BaseStorageBackend):
+    """Lmdb storage backend.
+
+    Args:
+        db_path (str): Lmdb database path.
+        readonly (bool, optional): Lmdb environment parameter. If True,
+            disallow any write operations. Default: True.
+        lock (bool, optional): Lmdb environment parameter. If False, when
+            concurrent access occurs, do not lock the database. Default: False.
+        readahead (bool, optional): Lmdb environment parameter. If False,
+            disable the OS filesystem readahead mechanism, which may improve
+            random read performance when a database is larger than RAM.
+            Default: False.
+
+    Attributes:
+        db_path (str): Lmdb database path.
+    """
+
+    def __init__(self,
+                 db_path,
+                 readonly=True,
+                 lock=False,
+                 readahead=False,
+                 **kwargs):
+        try:
+            import lmdb
+        except ImportError:
+            raise ImportError('Please install lmdb to enable LmdbBackend.')
+
+        self.db_path = str(db_path)
+        self._client = lmdb.open(
+            self.db_path,
+            readonly=readonly,
+            lock=lock,
+            readahead=readahead,
+            **kwargs)
+
+    def get(self, filepath):
+        """Get values according to the filepath.
+
+        Args:
+            filepath (str | obj:`Path`): Here, filepath is the lmdb key.
+        """
+        filepath = str(filepath)
+        with self._client.begin(write=False) as txn:
+            value_buf = txn.get(filepath.encode('ascii'))
+        return value_buf
+
+    def get_text(self, filepath, encoding=None):
+        raise NotImplementedError
+
+
+class HardDiskBackend(BaseStorageBackend):
+    """Raw hard disks storage backend."""
+
+    _allow_symlink = True
+
+    def get(self, filepath: Union[str, Path]) -> bytes:
+        """Read data from a given ``filepath`` with 'rb' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+
+        Returns:
+            bytes: Expected bytes object.
+        """
+        with open(filepath, 'rb') as f:
+            value_buf = f.read()
+        return value_buf
+
+    def get_text(self,
+                 filepath: Union[str, Path],
+                 encoding: str = 'utf-8') -> str:
+        """Read data from a given ``filepath`` with 'r' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+
+        Returns:
+            str: Expected text reading from ``filepath``.
+        """
+        with open(filepath, 'r', encoding=encoding) as f:
+            value_buf = f.read()
+        return value_buf
+
+    def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
+        """Write data to a given ``filepath`` with 'wb' mode.
+
+        Note:
+            ``put`` will create a directory if the directory of ``filepath``
+            does not exist.
+
+        Args:
+            obj (bytes): Data to be written.
+            filepath (str or Path): Path to write data.
+        """
+        mmcv.mkdir_or_exist(osp.dirname(filepath))
+        with open(filepath, 'wb') as f:
+            f.write(obj)
+
+    def put_text(self,
+                 obj: str,
+                 filepath: Union[str, Path],
+                 encoding: str = 'utf-8') -> None:
+        """Write data to a given ``filepath`` with 'w' mode.
+
+        Note:
+            ``put_text`` will create a directory if the directory of
+            ``filepath`` does not exist.
+
+        Args:
+            obj (str): Data to be written.
+            filepath (str or Path): Path to write data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+        """
+        mmcv.mkdir_or_exist(osp.dirname(filepath))
+        with open(filepath, 'w', encoding=encoding) as f:
+            f.write(obj)
+
+    def remove(self, filepath: Union[str, Path]) -> None:
+        """Remove a file.
+
+        Args:
+            filepath (str or Path): Path to be removed.
+        """
+        os.remove(filepath)
+
+    def exists(self, filepath: Union[str, Path]) -> bool:
+        """Check whether a file path exists.
+
+        Args:
+            filepath (str or Path): Path to be checked whether exists.
+
+        Returns:
+            bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
+        """
+        return osp.exists(filepath)
+
+    def isdir(self, filepath: Union[str, Path]) -> bool:
+        """Check whether a file path is a directory.
+
+        Args:
+            filepath (str or Path): Path to be checked whether it is a
+                directory.
+
+        Returns:
+            bool: Return ``True`` if ``filepath`` points to a directory,
+                ``False`` otherwise.
+        """
+        return osp.isdir(filepath)
+
+    def isfile(self, filepath: Union[str, Path]) -> bool:
+        """Check whether a file path is a file.
+
+        Args:
+            filepath (str or Path): Path to be checked whether it is a file.
+
+        Returns:
+            bool: Return ``True`` if ``filepath`` points to a file, ``False``
+                otherwise.
+        """
+        return osp.isfile(filepath)
+
+    def join_path(self, filepath: Union[str, Path],
+                  *filepaths: Union[str, Path]) -> str:
+        """Concatenate all file paths.
+
+        Join one or more filepath components intelligently. The return value
+        is the concatenation of filepath and any members of *filepaths.
+
+        Args:
+            filepath (str or Path): Path to be concatenated.
+
+        Returns:
+            str: The result of concatenation.
+        """
+        return osp.join(filepath, *filepaths)
+
+    @contextmanager
+    def get_local_path(
+            self, filepath: Union[str, Path]) -> Iterable[Union[str, Path]]:
+        """Only for unified API and do nothing."""
+        yield filepath
+
+    def list_dir_or_file(self,
+                         dir_path: Union[str, Path],
+                         list_dir: bool = True,
+                         list_file: bool = True,
+                         suffix: Optional[Union[str, Tuple[str]]] = None,
+                         recursive: bool = False) -> Iterator[str]:
+        """Scan a directory to find the directories or files of interest, in
+        arbitrary order.
+
+        Note:
+            :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
+
+        Args:
+            dir_path (str | Path): Path of the directory.
+            list_dir (bool): List the directories. Default: True.
+            list_file (bool): List the path of files. Default: True.
+            suffix (str or tuple[str], optional):  File suffix
+                that we are interested in. Default: None.
+            recursive (bool): If set to True, recursively scan the
+                directory. Default: False.
+
+        Yields:
+            Iterable[str]: A relative path to ``dir_path``.
+        """
+        if list_dir and suffix is not None:
+            raise TypeError('`suffix` should be None when `list_dir` is True')
+
+        if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+            raise TypeError('`suffix` must be a string or tuple of strings')
+
+        root = dir_path
+
+        def _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+                              recursive):
+            for entry in os.scandir(dir_path):
+                if not entry.name.startswith('.') and entry.is_file():
+                    rel_path = osp.relpath(entry.path, root)
+                    if (suffix is None
+                            or rel_path.endswith(suffix)) and list_file:
+                        yield rel_path
+                elif osp.isdir(entry.path):
+                    if list_dir:
+                        rel_dir = osp.relpath(entry.path, root)
+                        yield rel_dir
+                    if recursive:
+                        yield from _list_dir_or_file(entry.path, list_dir,
+                                                     list_file, suffix,
+                                                     recursive)
+
+        return _list_dir_or_file(dir_path, list_dir, list_file, suffix,
+                                 recursive)
+
+
+class HTTPBackend(BaseStorageBackend):
+    """HTTP and HTTPS storage backend."""
+
+    def get(self, filepath):
+        value_buf = urlopen(filepath).read()
+        return value_buf
+
+    def get_text(self, filepath, encoding='utf-8'):
+        value_buf = urlopen(filepath).read()
+        return value_buf.decode(encoding)
+
+    @contextmanager
+    def get_local_path(self, filepath: str) -> Iterable[str]:
+        """Download a file from ``filepath``.
+
+        ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`.
+        It can be called within a ``with`` statement, and when exiting the
+        ``with`` statement, the temporary path will be released.
+
+        Args:
+            filepath (str): Download a file from ``filepath``.
+
+        Examples:
+            >>> client = HTTPBackend()
+            >>> # After exiting from the ``with`` clause,
+            >>> # the path will be removed
+            >>> with client.get_local_path('http://path/of/your/file') as path:
+            ...     # do something here
+        """
+        try:
+            f = tempfile.NamedTemporaryFile(delete=False)
+            f.write(self.get(filepath))
+            f.close()
+            yield f.name
+        finally:
+            os.remove(f.name)
+
+
+class FileClient:
+    """A general file client to access files in different backends.
+
+    The client loads a file or text in a specified backend from its path
+    and returns its contents as bytes or text. There are two ways to choose a
+    backend: the name of the backend and the prefix of the path. If both are
+    set, ``backend`` has higher priority and determines the storage backend;
+    if both are `None`, the disk backend is chosen. Other backend accessors
+    can also be registered with a given name, prefixes, and backend class. In
+    addition, the singleton pattern is used to avoid repeated object creation:
+    if the arguments are the same, the same object is returned.
+
+    Args:
+        backend (str, optional): The storage backend type. Options are "disk",
+            "ceph", "memcached", "lmdb", "http" and "petrel". Default: None.
+        prefix (str, optional): The prefix of the registered storage backend.
+            Options are "s3", "http", "https". Default: None.
+
+    Examples:
+        >>> # only set backend
+        >>> file_client = FileClient(backend='petrel')
+        >>> # only set prefix
+        >>> file_client = FileClient(prefix='s3')
+        >>> # set both backend and prefix but use backend to choose client
+        >>> file_client = FileClient(backend='petrel', prefix='s3')
+        >>> # if the arguments are the same, the same object is returned
+        >>> file_client1 = FileClient(backend='petrel')
+        >>> file_client1 is file_client
+        True
+
+    Attributes:
+        client (:obj:`BaseStorageBackend`): The backend object.
+    """
+
+    _backends = {
+        'disk': HardDiskBackend,
+        'ceph': CephBackend,
+        'memcached': MemcachedBackend,
+        'lmdb': LmdbBackend,
+        'petrel': PetrelBackend,
+        'http': HTTPBackend,
+    }
+    # This collection records the overridden backends. When a backend appears
+    # in the collection, the singleton pattern is disabled for that backend,
+    # because otherwise the returned object would be the backend registered
+    # before overwriting.
+    _overridden_backends = set()
+    _prefix_to_backends = {
+        's3': PetrelBackend,
+        'http': HTTPBackend,
+        'https': HTTPBackend,
+    }
+    _overridden_prefixes = set()
+
+    _instances = {}
+
+    def __new__(cls, backend=None, prefix=None, **kwargs):
+        if backend is None and prefix is None:
+            backend = 'disk'
+        if backend is not None and backend not in cls._backends:
+            raise ValueError(
+                f'Backend {backend} is not supported. Currently supported ones'
+                f' are {list(cls._backends.keys())}')
+        if prefix is not None and prefix not in cls._prefix_to_backends:
+            raise ValueError(
+                f'prefix {prefix} is not supported. Currently supported ones '
+                f'are {list(cls._prefix_to_backends.keys())}')
+
+        # concatenate the arguments to a unique key for determining whether
+        # objects with the same arguments were created
+        arg_key = f'{backend}:{prefix}'
+        for key, value in kwargs.items():
+            arg_key += f':{key}:{value}'
+
+        # if a backend was overridden, it will create a new object
+        if (arg_key in cls._instances
+                and backend not in cls._overridden_backends
+                and prefix not in cls._overridden_prefixes):
+            _instance = cls._instances[arg_key]
+        else:
+            # create a new object and put it to _instance
+            _instance = super().__new__(cls)
+            if backend is not None:
+                _instance.client = cls._backends[backend](**kwargs)
+            else:
+                _instance.client = cls._prefix_to_backends[prefix](**kwargs)
+
+            cls._instances[arg_key] = _instance
+
+        return _instance
+
+    @property
+    def name(self):
+        return self.client.name
+
+    @property
+    def allow_symlink(self):
+        return self.client.allow_symlink
+
+    @staticmethod
+    def parse_uri_prefix(uri: Union[str, Path]) -> Optional[str]:
+        """Parse the prefix of a uri.
+
+        Args:
+            uri (str | Path): Uri to be parsed that contains the file prefix.
+
+        Examples:
+            >>> FileClient.parse_uri_prefix('s3://path/of/your/file')
+            's3'
+
+        Returns:
+            str | None: Return the prefix of uri if the uri contains '://'
+                else ``None``.
+        """
+        assert is_filepath(uri)
+        uri = str(uri)
+        if '://' not in uri:
+            return None
+        else:
+            prefix, _ = uri.split('://')
+            # In the case of PetrelBackend, the prefix may contain the cluster
+            # name, e.g. clusterName:s3
+            if ':' in prefix:
+                _, prefix = prefix.split(':')
+            return prefix
+
+    @classmethod
+    def infer_client(cls,
+                     file_client_args: Optional[dict] = None,
+                     uri: Optional[Union[str, Path]] = None) -> 'FileClient':
+        """Infer a suitable file client based on the URI and arguments.
+
+        Args:
+            file_client_args (dict, optional): Arguments to instantiate a
+                FileClient. Default: None.
+            uri (str | Path, optional): Uri to be parsed that contains the file
+                prefix. Default: None.
+
+        Examples:
+            >>> uri = 's3://path/of/your/file'
+            >>> file_client = FileClient.infer_client(uri=uri)
+            >>> file_client_args = {'backend': 'petrel'}
+            >>> file_client = FileClient.infer_client(file_client_args)
+
+        Returns:
+            FileClient: Instantiated FileClient object.
+        """
+        assert file_client_args is not None or uri is not None
+        if file_client_args is None:
+            file_prefix = cls.parse_uri_prefix(uri)  # type: ignore
+            return cls(prefix=file_prefix)
+        else:
+            return cls(**file_client_args)
+
+    @classmethod
+    def _register_backend(cls, name, backend, force=False, prefixes=None):
+        if not isinstance(name, str):
+            raise TypeError('the backend name should be a string, '
+                            f'but got {type(name)}')
+        if not inspect.isclass(backend):
+            raise TypeError(
+                f'backend should be a class but got {type(backend)}')
+        if not issubclass(backend, BaseStorageBackend):
+            raise TypeError(
+                f'backend {backend} is not a subclass of BaseStorageBackend')
+        if not force and name in cls._backends:
+            raise KeyError(
+                f'{name} is already registered as a storage backend, '
+                'add "force=True" if you want to override it')
+
+        if name in cls._backends and force:
+            cls._overridden_backends.add(name)
+        cls._backends[name] = backend
+
+        if prefixes is not None:
+            if isinstance(prefixes, str):
+                prefixes = [prefixes]
+            else:
+                assert isinstance(prefixes, (list, tuple))
+            for prefix in prefixes:
+                if prefix not in cls._prefix_to_backends:
+                    cls._prefix_to_backends[prefix] = backend
+                elif (prefix in cls._prefix_to_backends) and force:
+                    cls._overridden_prefixes.add(prefix)
+                    cls._prefix_to_backends[prefix] = backend
+                else:
+                    raise KeyError(
+                        f'{prefix} is already registered as a storage backend,'
+                        ' add "force=True" if you want to override it')
+
+    @classmethod
+    def register_backend(cls, name, backend=None, force=False, prefixes=None):
+        """Register a backend to FileClient.
+
+        This method can be used as a normal class method or a decorator.
+
+        .. code-block:: python
+
+            class NewBackend(BaseStorageBackend):
+
+                def get(self, filepath):
+                    return filepath
+
+                def get_text(self, filepath):
+                    return filepath
+
+            FileClient.register_backend('new', NewBackend)
+
+        or
+
+        .. code-block:: python
+
+            @FileClient.register_backend('new')
+            class NewBackend(BaseStorageBackend):
+
+                def get(self, filepath):
+                    return filepath
+
+                def get_text(self, filepath):
+                    return filepath
+
+        Args:
+            name (str): The name of the registered backend.
+            backend (class, optional): The backend class to be registered,
+                which must be a subclass of :class:`BaseStorageBackend`.
+                When this method is used as a decorator, backend is None.
+                Defaults to None.
+            force (bool, optional): Whether to override the backend if the name
+                has already been registered. Defaults to False.
+            prefixes (str or list[str] or tuple[str], optional): The prefixes
+                of the registered storage backend. Default: None.
+                `New in version 1.3.15.`
+        """
+        if backend is not None:
+            cls._register_backend(
+                name, backend, force=force, prefixes=prefixes)
+            return
+
+        def _register(backend_cls):
+            cls._register_backend(
+                name, backend_cls, force=force, prefixes=prefixes)
+            return backend_cls
+
+        return _register
+
+    def get(self, filepath: Union[str, Path]) -> Union[bytes, memoryview]:
+        """Read data from a given ``filepath`` with 'rb' mode.
+
+        Note:
+            There are two types of return values for ``get``, one is ``bytes``
+            and the other is ``memoryview``. The advantage of using memoryview
+            is that you can avoid copying, and if you want to convert it to
+            ``bytes``, you can use ``.tobytes()``.
+
+        Args:
+            filepath (str or Path): Path to read data.
+
+        Returns:
+            bytes | memoryview: Expected bytes object or a memory view of the
+                bytes object.
+        """
+        return self.client.get(filepath)
+
+    def get_text(self, filepath: Union[str, Path], encoding='utf-8') -> str:
+        """Read data from a given ``filepath`` with 'r' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+
+        Returns:
+            str: Expected text reading from ``filepath``.
+        """
+        return self.client.get_text(filepath, encoding)
+
+    def put(self, obj: bytes, filepath: Union[str, Path]) -> None:
+        """Write data to a given ``filepath`` with 'wb' mode.
+
+        Note:
+            ``put`` should create a directory if the directory of ``filepath``
+            does not exist.
+
+        Args:
+            obj (bytes): Data to be written.
+            filepath (str or Path): Path to write data.
+        """
+        self.client.put(obj, filepath)
+
+    def put_text(self, obj: str, filepath: Union[str, Path]) -> None:
+        """Write data to a given ``filepath`` with 'w' mode.
+
+        Note:
+            ``put_text`` should create a directory if the directory of
+            ``filepath`` does not exist.
+
+        Args:
+            obj (str): Data to be written.
+            filepath (str or Path): Path to write data.
+        """
+        self.client.put_text(obj, filepath)
+
+    def remove(self, filepath: Union[str, Path]) -> None:
+        """Remove a file.
+
+        Args:
+            filepath (str or Path): Path to be removed.
+        """
+        self.client.remove(filepath)
+
+    def exists(self, filepath: Union[str, Path]) -> bool:
+        """Check whether a file path exists.
+
+        Args:
+            filepath (str or Path): Path to be checked whether exists.
+
+        Returns:
+            bool: Return ``True`` if ``filepath`` exists, ``False`` otherwise.
+        """
+        return self.client.exists(filepath)
+
+    def isdir(self, filepath: Union[str, Path]) -> bool:
+        """Check whether a file path is a directory.
+
+        Args:
+            filepath (str or Path): Path to be checked whether it is a
+                directory.
+
+        Returns:
+            bool: Return ``True`` if ``filepath`` points to a directory,
+                ``False`` otherwise.
+        """
+        return self.client.isdir(filepath)
+
+    def isfile(self, filepath: Union[str, Path]) -> bool:
+        """Check whether a file path is a file.
+
+        Args:
+            filepath (str or Path): Path to be checked whether it is a file.
+
+        Returns:
+            bool: Return ``True`` if ``filepath`` points to a file, ``False``
+                otherwise.
+        """
+        return self.client.isfile(filepath)
+
+    def join_path(self, filepath: Union[str, Path],
+                  *filepaths: Union[str, Path]) -> str:
+        """Concatenate all file paths.
+
+        Join one or more filepath components intelligently. The return value
+        is the concatenation of filepath and any members of *filepaths.
+
+        Args:
+            filepath (str or Path): Path to be concatenated.
+
+        Returns:
+            str: The result of concatenation.
+        """
+        return self.client.join_path(filepath, *filepaths)
+
+    @contextmanager
+    def get_local_path(self, filepath: Union[str, Path]) -> Iterable[str]:
+        """Download data from ``filepath`` and write the data to local path.
+
+        ``get_local_path`` is decorated by :meth:`contextlib.contextmanager`.
+        It can be called within a ``with`` statement, and when exiting the
+        ``with`` statement, the temporary path will be released.
+
+        Note:
+            If the ``filepath`` is a local path, just return itself.
+
+        .. warning::
+            ``get_local_path`` is an experimental interface that may change in
+            the future.
+
+        Args:
+            filepath (str or Path): Path of the data to be read.
+
+        Examples:
+            >>> file_client = FileClient(prefix='s3')
+            >>> with file_client.get_local_path('s3://bucket/abc.jpg') as path:
+            ...     # do something here
+
+        Yields:
+            Iterable[str]: Only yield one path.
+        """
+        with self.client.get_local_path(str(filepath)) as local_path:
+            yield local_path
+
+    def list_dir_or_file(self,
+                         dir_path: Union[str, Path],
+                         list_dir: bool = True,
+                         list_file: bool = True,
+                         suffix: Optional[Union[str, Tuple[str]]] = None,
+                         recursive: bool = False) -> Iterator[str]:
+        """Scan a directory to find the directories or files of interest, in
+        arbitrary order.
+
+        Note:
+            :meth:`list_dir_or_file` returns the path relative to ``dir_path``.
+
+        Args:
+            dir_path (str | Path): Path of the directory.
+            list_dir (bool): List the directories. Default: True.
+            list_file (bool): List the path of files. Default: True.
+            suffix (str or tuple[str], optional):  File suffix
+                that we are interested in. Default: None.
+            recursive (bool): If set to True, recursively scan the
+                directory. Default: False.
+
+        Yields:
+            Iterable[str]: A relative path to ``dir_path``.
+        """
+        yield from self.client.list_dir_or_file(dir_path, list_dir, list_file,
+                                                suffix, recursive)
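A minimal usage sketch of the ``FileClient`` API above (illustrative only: the ``DummyBackend`` class, the ``dummy://`` prefix and all paths are made up, and the imports assume this vendored module path is importable):

import tempfile

from annotator.uniformer.mmcv.fileio.file_client import (BaseStorageBackend,
                                                         FileClient)


@FileClient.register_backend('dummy', prefixes='dummy')
class DummyBackend(BaseStorageBackend):
    """Toy backend that simply echoes the requested path."""

    def get(self, filepath):
        return str(filepath).encode()

    def get_text(self, filepath, encoding='utf-8'):
        return str(filepath)


# With neither backend nor prefix given, FileClient falls back to the disk
# backend and behaves like a thin wrapper around open()/os.path.
with tempfile.TemporaryDirectory() as tmpdir:
    client = FileClient()
    path = client.join_path(tmpdir, 'hello.txt')
    client.put_text('hello world', path)
    assert client.get_text(path) == 'hello world'
    assert list(client.list_dir_or_file(tmpdir, list_dir=False)) == ['hello.txt']

# Prefix-based dispatch: 'dummy://...' URIs now resolve to DummyBackend.
dummy_client = FileClient.infer_client(uri='dummy://some/key')
assert dummy_client.get_text('dummy://some/key') == 'dummy://some/key'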
diff --git a/annotator/uniformer/mmcv/fileio/handlers/__init__.py b/annotator/uniformer/mmcv/fileio/handlers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa24d91972837b8756b225f4879bac20436eb72a
--- /dev/null
+++ b/annotator/uniformer/mmcv/fileio/handlers/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import BaseFileHandler
+from .json_handler import JsonHandler
+from .pickle_handler import PickleHandler
+from .yaml_handler import YamlHandler
+
+__all__ = ['BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler']
diff --git a/annotator/uniformer/mmcv/fileio/handlers/base.py b/annotator/uniformer/mmcv/fileio/handlers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..288878bc57282fbb2f12b32290152ca8e9d3cab0
--- /dev/null
+++ b/annotator/uniformer/mmcv/fileio/handlers/base.py
@@ -0,0 +1,30 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+
+
+class BaseFileHandler(metaclass=ABCMeta):
+    # `str_like` is a flag indicating whether the file object type is
+    # str-like or bytes-like. Pickle only processes bytes-like objects,
+    # while json only processes str-like objects. If the handler is str-like,
+    # `StringIO` will be used to process the buffer.
+    str_like = True
+
+    @abstractmethod
+    def load_from_fileobj(self, file, **kwargs):
+        pass
+
+    @abstractmethod
+    def dump_to_fileobj(self, obj, file, **kwargs):
+        pass
+
+    @abstractmethod
+    def dump_to_str(self, obj, **kwargs):
+        pass
+
+    def load_from_path(self, filepath, mode='r', **kwargs):
+        with open(filepath, mode) as f:
+            return self.load_from_fileobj(f, **kwargs)
+
+    def dump_to_path(self, obj, filepath, mode='w', **kwargs):
+        with open(filepath, mode) as f:
+            self.dump_to_fileobj(obj, f, **kwargs)
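A small sketch of what the ``str_like`` flag controls, contrasting the two built-in handlers defined elsewhere in this diff (the import assumes the vendored package layout):

from annotator.uniformer.mmcv.fileio.handlers import JsonHandler, PickleHandler

assert JsonHandler.str_like is True       # text mode: io.py buffers via StringIO
assert PickleHandler.str_like is False    # binary mode: io.py buffers via BytesIO

# PickleHandler accordingly overrides load_from_path/dump_to_path to force
# 'rb'/'wb', while JsonHandler keeps the 'r'/'w' defaults inherited from
# BaseFileHandler.
print(JsonHandler().dump_to_str({'a': 1}))          # '{"a": 1}'
print(type(PickleHandler().dump_to_str({'a': 1})))  # <class 'bytes'>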
diff --git a/annotator/uniformer/mmcv/fileio/handlers/json_handler.py b/annotator/uniformer/mmcv/fileio/handlers/json_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..18d4f15f74139d20adff18b20be5529c592a66b6
--- /dev/null
+++ b/annotator/uniformer/mmcv/fileio/handlers/json_handler.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+
+import numpy as np
+
+from .base import BaseFileHandler
+
+
+def set_default(obj):
+    """Set default json values for non-serializable values.
+
+    It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list.
+    It also converts ``np.generic`` (including ``np.int32``, ``np.float32``,
+    etc.) into plain Python built-in numbers.
+    """
+    if isinstance(obj, (set, range)):
+        return list(obj)
+    elif isinstance(obj, np.ndarray):
+        return obj.tolist()
+    elif isinstance(obj, np.generic):
+        return obj.item()
+    raise TypeError(f'{type(obj)} is unsupported for json dump')
+
+
+class JsonHandler(BaseFileHandler):
+
+    def load_from_fileobj(self, file):
+        return json.load(file)
+
+    def dump_to_fileobj(self, obj, file, **kwargs):
+        kwargs.setdefault('default', set_default)
+        json.dump(obj, file, **kwargs)
+
+    def dump_to_str(self, obj, **kwargs):
+        kwargs.setdefault('default', set_default)
+        return json.dumps(obj, **kwargs)
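A brief sketch of what ``set_default`` buys when dumping through ``JsonHandler``: numpy scalars, numpy arrays and Python sets become JSON-serializable. The payload below is made up and the import assumes the vendored layout:

import numpy as np

from annotator.uniformer.mmcv.fileio.handlers.json_handler import JsonHandler

handler = JsonHandler()
payload = {'score': np.float32(0.5), 'ids': np.arange(3), 'tags': {'a'}}
print(handler.dump_to_str(payload, sort_keys=True))
# expected: {"ids": [0, 1, 2], "score": 0.5, "tags": ["a"]}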
diff --git a/annotator/uniformer/mmcv/fileio/handlers/pickle_handler.py b/annotator/uniformer/mmcv/fileio/handlers/pickle_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b37c79bed4ef9fd8913715e62dbe3fc5cafdc3aa
--- /dev/null
+++ b/annotator/uniformer/mmcv/fileio/handlers/pickle_handler.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pickle
+
+from .base import BaseFileHandler
+
+
+class PickleHandler(BaseFileHandler):
+
+    str_like = False
+
+    def load_from_fileobj(self, file, **kwargs):
+        return pickle.load(file, **kwargs)
+
+    def load_from_path(self, filepath, **kwargs):
+        return super(PickleHandler, self).load_from_path(
+            filepath, mode='rb', **kwargs)
+
+    def dump_to_str(self, obj, **kwargs):
+        kwargs.setdefault('protocol', 2)
+        return pickle.dumps(obj, **kwargs)
+
+    def dump_to_fileobj(self, obj, file, **kwargs):
+        kwargs.setdefault('protocol', 2)
+        pickle.dump(obj, file, **kwargs)
+
+    def dump_to_path(self, obj, filepath, **kwargs):
+        super(PickleHandler, self).dump_to_path(
+            obj, filepath, mode='wb', **kwargs)
diff --git a/annotator/uniformer/mmcv/fileio/handlers/yaml_handler.py b/annotator/uniformer/mmcv/fileio/handlers/yaml_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5aa2eea1e8c76f8baf753d1c8c959dee665e543
--- /dev/null
+++ b/annotator/uniformer/mmcv/fileio/handlers/yaml_handler.py
@@ -0,0 +1,24 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import yaml
+
+try:
+    from yaml import CLoader as Loader, CDumper as Dumper
+except ImportError:
+    from yaml import Loader, Dumper
+
+from .base import BaseFileHandler  # isort:skip
+
+
+class YamlHandler(BaseFileHandler):
+
+    def load_from_fileobj(self, file, **kwargs):
+        kwargs.setdefault('Loader', Loader)
+        return yaml.load(file, **kwargs)
+
+    def dump_to_fileobj(self, obj, file, **kwargs):
+        kwargs.setdefault('Dumper', Dumper)
+        yaml.dump(obj, file, **kwargs)
+
+    def dump_to_str(self, obj, **kwargs):
+        kwargs.setdefault('Dumper', Dumper)
+        return yaml.dump(obj, **kwargs)
diff --git a/annotator/uniformer/mmcv/fileio/io.py b/annotator/uniformer/mmcv/fileio/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaefde58aa3ea5b58f86249ce7e1c40c186eb8dd
--- /dev/null
+++ b/annotator/uniformer/mmcv/fileio/io.py
@@ -0,0 +1,151 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from io import BytesIO, StringIO
+from pathlib import Path
+
+from ..utils import is_list_of, is_str
+from .file_client import FileClient
+from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
+
+file_handlers = {
+    'json': JsonHandler(),
+    'yaml': YamlHandler(),
+    'yml': YamlHandler(),
+    'pickle': PickleHandler(),
+    'pkl': PickleHandler()
+}
+
+
+def load(file, file_format=None, file_client_args=None, **kwargs):
+    """Load data from json/yaml/pickle files.
+
+    This method provides a unified api for loading data from serialized files.
+
+    Note:
+        In v1.3.16 and later, ``load`` supports loading data from serialized
+        files that can be stored in different backends.
+
+    Args:
+        file (str or :obj:`Path` or file-like object): Filename or a file-like
+            object.
+        file_format (str, optional): If not specified, the file format will be
+            inferred from the file extension, otherwise use the specified one.
+            Currently supported formats include "json", "yaml/yml" and
+            "pickle/pkl".
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.
+
+    Examples:
+        >>> load('/path/of/your/file')  # file is stored on disk
+        >>> load('https://path/of/your/file')  # file is stored on the Internet
+        >>> load('s3://path/of/your/file')  # file is stored in petrel
+
+    Returns:
+        The content from the file.
+    """
+    if isinstance(file, Path):
+        file = str(file)
+    if file_format is None and is_str(file):
+        file_format = file.split('.')[-1]
+    if file_format not in file_handlers:
+        raise TypeError(f'Unsupported format: {file_format}')
+
+    handler = file_handlers[file_format]
+    if is_str(file):
+        file_client = FileClient.infer_client(file_client_args, file)
+        if handler.str_like:
+            with StringIO(file_client.get_text(file)) as f:
+                obj = handler.load_from_fileobj(f, **kwargs)
+        else:
+            with BytesIO(file_client.get(file)) as f:
+                obj = handler.load_from_fileobj(f, **kwargs)
+    elif hasattr(file, 'read'):
+        obj = handler.load_from_fileobj(file, **kwargs)
+    else:
+        raise TypeError('"file" must be a filepath str or a file-object')
+    return obj
+
+
+def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
+    """Dump data to json/yaml/pickle strings or files.
+
+    This method provides a unified api for dumping data as strings or to files,
+    and also supports custom arguments for each file format.
+
+    Note:
+        In v1.3.16 and later, ``dump`` supports dumping data as strings or to
+        files which can be saved to different backends.
+
+    Args:
+        obj (any): The python object to be dumped.
+        file (str or :obj:`Path` or file-like object, optional): If not
+            specified, then the object is dumped to a str, otherwise to a file
+            specified by the filename or file-like object.
+        file_format (str, optional): Same as :func:`load`.
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.
+
+    Examples:
+        >>> dump('hello world', '/path/of/your/file')  # disk
+        >>> dump('hello world', 's3://path/of/your/file')  # ceph or petrel
+
+    Returns:
+        str or None: The dumped string if ``file`` is None, otherwise None.
+    """
+    if isinstance(file, Path):
+        file = str(file)
+    if file_format is None:
+        if is_str(file):
+            file_format = file.split('.')[-1]
+        elif file is None:
+            raise ValueError(
+                'file_format must be specified since file is None')
+    if file_format not in file_handlers:
+        raise TypeError(f'Unsupported format: {file_format}')
+
+    handler = file_handlers[file_format]
+    if file is None:
+        return handler.dump_to_str(obj, **kwargs)
+    elif is_str(file):
+        file_client = FileClient.infer_client(file_client_args, file)
+        if handler.str_like:
+            with StringIO() as f:
+                handler.dump_to_fileobj(obj, f, **kwargs)
+                file_client.put_text(f.getvalue(), file)
+        else:
+            with BytesIO() as f:
+                handler.dump_to_fileobj(obj, f, **kwargs)
+                file_client.put(f.getvalue(), file)
+    elif hasattr(file, 'write'):
+        handler.dump_to_fileobj(obj, file, **kwargs)
+    else:
+        raise TypeError('"file" must be a filename str or a file-object')
+
+
+def _register_handler(handler, file_formats):
+    """Register a handler for some file extensions.
+
+    Args:
+        handler (:obj:`BaseFileHandler`): Handler to be registered.
+        file_formats (str or list[str]): File formats to be handled by this
+            handler.
+    """
+    if not isinstance(handler, BaseFileHandler):
+        raise TypeError(
+            f'handler must be a child of BaseFileHandler, not {type(handler)}')
+    if isinstance(file_formats, str):
+        file_formats = [file_formats]
+    if not is_list_of(file_formats, str):
+        raise TypeError('file_formats must be a str or a list of str')
+    for ext in file_formats:
+        file_handlers[ext] = handler
+
+
+def register_handler(file_formats, **kwargs):
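+    """Register a handler for some file extensions (decorator form).
+
+    The decorated class is instantiated with ``**kwargs`` and registered for
+    the given ``file_formats`` via :func:`_register_handler`, e.g.
+    ``@register_handler('txt')``.
+    """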
+
+    def wrap(cls):
+        _register_handler(cls(**kwargs), file_formats)
+        return cls
+
+    return wrap
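A hedged end-to-end sketch of the ``load``/``dump``/``register_handler`` trio above; the ``TxtHandler`` format, file names and paths are hypothetical, and the imports assume the vendored layout:

import os.path as osp
import tempfile

from annotator.uniformer.mmcv.fileio.handlers import BaseFileHandler
from annotator.uniformer.mmcv.fileio.io import dump, load, register_handler


@register_handler('txt')
class TxtHandler(BaseFileHandler):
    """Hypothetical handler mapping a file to a list of stripped lines."""

    def load_from_fileobj(self, file, **kwargs):
        return [line.rstrip('\n') for line in file]

    def dump_to_fileobj(self, obj, file, **kwargs):
        file.write(self.dump_to_str(obj, **kwargs))

    def dump_to_str(self, obj, **kwargs):
        return '\n'.join(str(item) for item in obj)


with tempfile.TemporaryDirectory() as tmpdir:
    json_path = osp.join(tmpdir, 'meta.json')
    dump({'classes': ['cat', 'dog']}, json_path)   # dispatched to JsonHandler
    assert load(json_path) == {'classes': ['cat', 'dog']}

    txt_path = osp.join(tmpdir, 'names.txt')
    dump(['cat', 'dog'], txt_path)                 # dispatched to TxtHandler
    assert load(txt_path) == ['cat', 'dog']

# Without a file, dump returns a string in the requested format.
assert dump(['a', 'b'], file_format='txt') == 'a\nb'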
diff --git a/annotator/uniformer/mmcv/fileio/parse.py b/annotator/uniformer/mmcv/fileio/parse.py
new file mode 100644
index 0000000000000000000000000000000000000000..f60f0d611b8d75692221d0edd7dc993b0a6445c9
--- /dev/null
+++ b/annotator/uniformer/mmcv/fileio/parse.py
@@ -0,0 +1,97 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from io import StringIO
+
+from .file_client import FileClient
+
+
+def list_from_file(filename,
+                   prefix='',
+                   offset=0,
+                   max_num=0,
+                   encoding='utf-8',
+                   file_client_args=None):
+    """Load a text file and parse the content as a list of strings.
+
+    Note:
+        In v1.3.16 and later, ``list_from_file`` supports loading a text file
+        which can be stored in different backends and parsing the content as
+        a list of strings.
+
+    Args:
+        filename (str): Filename.
+        prefix (str): The prefix to be inserted to the beginning of each item.
+        offset (int): The offset of lines.
+        max_num (int): The maximum number of lines to be read,
+            zero or a negative value means no limitation.
+        encoding (str): Encoding used to open the file. Default utf-8.
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.
+
+    Examples:
+        >>> list_from_file('/path/of/your/file')  # disk
+        ['hello', 'world']
+        >>> list_from_file('s3://path/of/your/file')  # ceph or petrel
+        ['hello', 'world']
+
+    Returns:
+        list[str]: A list of strings.
+    """
+    cnt = 0
+    item_list = []
+    file_client = FileClient.infer_client(file_client_args, filename)
+    with StringIO(file_client.get_text(filename, encoding)) as f:
+        for _ in range(offset):
+            f.readline()
+        for line in f:
+            if 0 < max_num <= cnt:
+                break
+            item_list.append(prefix + line.rstrip('\n\r'))
+            cnt += 1
+    return item_list
+
+
+def dict_from_file(filename,
+                   key_type=str,
+                   encoding='utf-8',
+                   file_client_args=None):
+    """Load a text file and parse the content as a dict.
+
+    Each line of the text file is split into two or more columns by
+    whitespace (spaces or tabs). The first column is parsed as the dict key,
+    and the following columns are parsed as the dict value.
+
+    Note:
+        In v1.3.16 and later, ``dict_from_file`` supports loading a text file
+        which can be stored in different backends and parsing the content as
+        a dict.
+
+    Args:
+        filename(str): Filename.
+        key_type(type): Type of the dict keys. str is used by default and
+            type conversion will be performed if specified.
+        encoding (str): Encoding used to open the file. Default utf-8.
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.
+
+    Examples:
+        >>> dict_from_file('/path/of/your/file')  # disk
+        {'key1': 'value1', 'key2': 'value2'}
+        >>> dict_from_file('s3://path/of/your/file')  # ceph or petrel
+        {'key1': 'value1', 'key2': 'value2'}
+
+    Returns:
+        dict: The parsed contents.
+    """
+    mapping = {}
+    file_client = FileClient.infer_client(file_client_args, filename)
+    with StringIO(file_client.get_text(filename, encoding)) as f:
+        for line in f:
+            items = line.rstrip('\n').split()
+            assert len(items) >= 2
+            key = key_type(items[0])
+            val = items[1:] if len(items) > 2 else items[1]
+            mapping[key] = val
+    return mapping
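A small, self-contained sketch of the two parsers above against a local file; the file contents and path are made up, and the import assumes the vendored layout:

import os.path as osp
import tempfile

from annotator.uniformer.mmcv.fileio.parse import dict_from_file, list_from_file

with tempfile.TemporaryDirectory() as tmpdir:
    path = osp.join(tmpdir, 'classes.txt')
    with open(path, 'w', encoding='utf-8') as f:
        f.write('0 cat\n1 dog husky\n2 bird\n')

    # One string per line, optionally skipped (offset), limited (max_num)
    # and prefixed.
    print(list_from_file(path, offset=1, max_num=1, prefix='cls: '))
    # expected: ['cls: 1 dog husky']

    # First column -> key (cast via key_type); remaining columns -> value.
    print(dict_from_file(path, key_type=int))
    # expected: {0: 'cat', 1: ['dog', 'husky'], 2: 'bird'}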
diff --git a/annotator/uniformer/mmcv/image/__init__.py b/annotator/uniformer/mmcv/image/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0051d609d3de4e7562e3fe638335c66617c4d91
--- /dev/null
+++ b/annotator/uniformer/mmcv/image/__init__.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .colorspace import (bgr2gray, bgr2hls, bgr2hsv, bgr2rgb, bgr2ycbcr,
+                         gray2bgr, gray2rgb, hls2bgr, hsv2bgr, imconvert,
+                         rgb2bgr, rgb2gray, rgb2ycbcr, ycbcr2bgr, ycbcr2rgb)
+from .geometric import (cutout, imcrop, imflip, imflip_, impad,
+                        impad_to_multiple, imrescale, imresize, imresize_like,
+                        imresize_to_multiple, imrotate, imshear, imtranslate,
+                        rescale_size)
+from .io import imfrombytes, imread, imwrite, supported_backends, use_backend
+from .misc import tensor2imgs
+from .photometric import (adjust_brightness, adjust_color, adjust_contrast,
+                          adjust_lighting, adjust_sharpness, auto_contrast,
+                          clahe, imdenormalize, imequalize, iminvert,
+                          imnormalize, imnormalize_, lut_transform, posterize,
+                          solarize)
+
+__all__ = [
+    'bgr2gray', 'bgr2hls', 'bgr2hsv', 'bgr2rgb', 'gray2bgr', 'gray2rgb',
+    'hls2bgr', 'hsv2bgr', 'imconvert', 'rgb2bgr', 'rgb2gray', 'imrescale',
+    'imresize', 'imresize_like', 'imresize_to_multiple', 'rescale_size',
+    'imcrop', 'imflip', 'imflip_', 'impad', 'impad_to_multiple', 'imrotate',
+    'imfrombytes', 'imread', 'imwrite', 'supported_backends', 'use_backend',
+    'imdenormalize', 'imnormalize', 'imnormalize_', 'iminvert', 'posterize',
+    'solarize', 'rgb2ycbcr', 'bgr2ycbcr', 'ycbcr2rgb', 'ycbcr2bgr',
+    'tensor2imgs', 'imshear', 'imtranslate', 'adjust_color', 'imequalize',
+    'adjust_brightness', 'adjust_contrast', 'lut_transform', 'clahe',
+    'adjust_sharpness', 'auto_contrast', 'cutout', 'adjust_lighting'
+]
diff --git a/annotator/uniformer/mmcv/image/colorspace.py b/annotator/uniformer/mmcv/image/colorspace.py
new file mode 100644
index 0000000000000000000000000000000000000000..814533952fdfda23d67cb6a3073692d8c1156add
--- /dev/null
+++ b/annotator/uniformer/mmcv/image/colorspace.py
@@ -0,0 +1,306 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import cv2
+import numpy as np
+
+
+def imconvert(img, src, dst):
+    """Convert an image from the src colorspace to dst colorspace.
+
+    Args:
+        img (ndarray): The input image.
+        src (str): The source colorspace, e.g., 'rgb', 'hsv'.
+        dst (str): The destination colorspace, e.g., 'rgb', 'hsv'.
+
+    Returns:
+        ndarray: The converted image.
+    """
+    code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}')
+    out_img = cv2.cvtColor(img, code)
+    return out_img
+
+
+def bgr2gray(img, keepdim=False):
+    """Convert a BGR image to grayscale image.
+
+    Args:
+        img (ndarray): The input image.
+        keepdim (bool): If False (by default), then return the grayscale image
+            with 2 dims, otherwise 3 dims.
+
+    Returns:
+        ndarray: The converted grayscale image.
+    """
+    out_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+    if keepdim:
+        out_img = out_img[..., None]
+    return out_img
+
+
+def rgb2gray(img, keepdim=False):
+    """Convert a RGB image to grayscale image.
+
+    Args:
+        img (ndarray): The input image.
+        keepdim (bool): If False (by default), then return the grayscale image
+            with 2 dims, otherwise 3 dims.
+
+    Returns:
+        ndarray: The converted grayscale image.
+    """
+    out_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+    if keepdim:
+        out_img = out_img[..., None]
+    return out_img
+
+
+def gray2bgr(img):
+    """Convert a grayscale image to BGR image.
+
+    Args:
+        img (ndarray): The input image.
+
+    Returns:
+        ndarray: The converted BGR image.
+    """
+    img = img[..., None] if img.ndim == 2 else img
+    out_img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+    return out_img
+
+
+def gray2rgb(img):
+    """Convert a grayscale image to RGB image.
+
+    Args:
+        img (ndarray): The input image.
+
+    Returns:
+        ndarray: The converted RGB image.
+    """
+    img = img[..., None] if img.ndim == 2 else img
+    out_img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+    return out_img
+
+
+def _convert_input_type_range(img):
+    """Convert the type and range of the input image.
+
+    It converts the input image to np.float32 type and range of [0, 1].
+    It is mainly used for pre-processing the input image in colorspace
+    conversion functions such as rgb2ycbcr and ycbcr2rgb.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+
+    Returns:
+        (ndarray): The converted image with type of np.float32 and range of
+            [0, 1].
+    """
+    img_type = img.dtype
+    img = img.astype(np.float32)
+    if img_type == np.float32:
+        pass
+    elif img_type == np.uint8:
+        img /= 255.
+    else:
+        raise TypeError('The img type should be np.float32 or np.uint8, '
+                        f'but got {img_type}')
+    return img
+
+
+def _convert_output_type_range(img, dst_type):
+    """Convert the type and range of the image according to dst_type.
+
+    It converts the image to desired type and range. If `dst_type` is np.uint8,
+    images will be converted to np.uint8 type with range [0, 255]. If
+    `dst_type` is np.float32, it converts the image to np.float32 type with
+    range [0, 1].
+    It is mainly used for post-processing images in colorspace conversion
+    functions such as rgb2ycbcr and ycbcr2rgb.
+
+    Args:
+        img (ndarray): The image to be converted with np.float32 type and
+            range [0, 255].
+        dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
+            converts the image to np.uint8 type with range [0, 255]. If
+            dst_type is np.float32, it converts the image to np.float32 type
+            with range [0, 1].
+
+    Returns:
+        (ndarray): The converted image with desired type and range.
+    """
+    if dst_type not in (np.uint8, np.float32):
+        raise TypeError('The dst_type should be np.float32 or np.uint8, '
+                        f'but got {dst_type}')
+    if dst_type == np.uint8:
+        img = img.round()
+    else:
+        img /= 255.
+    return img.astype(dst_type)
+
+
+def rgb2ycbcr(img, y_only=False):
+    """Convert a RGB image to YCbCr image.
+
+    This function produces the same results as Matlab's `rgb2ycbcr` function.
+    It implements the ITU-R BT.601 conversion for standard-definition
+    television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
+    In OpenCV, it implements a JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+        y_only (bool): Whether to only return Y channel. Default: False.
+
+    Returns:
+        ndarray: The converted YCbCr image. The output image has the same type
+            and range as input image.
+    """
+    img_type = img.dtype
+    img = _convert_input_type_range(img)
+    if y_only:
+        out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
+    else:
+        out_img = np.matmul(
+            img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+                  [24.966, 112.0, -18.214]]) + [16, 128, 128]
+    out_img = _convert_output_type_range(out_img, img_type)
+    return out_img
+
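+# Worked example (illustrative): a uint8 white pixel (255, 255, 255) maps to
+# (235, 128, 128) under rgb2ycbcr. The input is first scaled to 1.0, so
+# Y = 65.481 + 128.553 + 24.966 + 16 = 235, while the Cb and Cr coefficient
+# columns each sum to 0 before the +128 offsets, matching the limited-range
+# BT.601 convention.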
+
+def bgr2ycbcr(img, y_only=False):
+    """Convert a BGR image to YCbCr image.
+
+    The bgr version of rgb2ycbcr.
+    It implements the ITU-R BT.601 conversion for standard-definition
+    television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
+    In OpenCV, it implements a JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+        y_only (bool): Whether to only return Y channel. Default: False.
+
+    Returns:
+        ndarray: The converted YCbCr image. The output image has the same type
+            and range as input image.
+    """
+    img_type = img.dtype
+    img = _convert_input_type_range(img)
+    if y_only:
+        out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
+    else:
+        out_img = np.matmul(
+            img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+                  [65.481, -37.797, 112.0]]) + [16, 128, 128]
+    out_img = _convert_output_type_range(out_img, img_type)
+    return out_img
+
+
+def ycbcr2rgb(img):
+    """Convert a YCbCr image to RGB image.
+
+    This function produces the same results as Matlab's ycbcr2rgb function.
+    It implements the ITU-R BT.601 conversion for standard-definition
+    television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
+    In OpenCV, it implements a JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+
+    Returns:
+        ndarray: The converted RGB image. The output image has the same type
+            and range as input image.
+    """
+    img_type = img.dtype
+    img = _convert_input_type_range(img) * 255
+    out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
+                              [0, -0.00153632, 0.00791071],
+                              [0.00625893, -0.00318811, 0]]) * 255.0 + [
+                                  -222.921, 135.576, -276.836
+                              ]
+    out_img = _convert_output_type_range(out_img, img_type)
+    return out_img
+
+
+def ycbcr2bgr(img):
+    """Convert a YCbCr image to BGR image.
+
+    The bgr version of ycbcr2rgb.
+    It implements the ITU-R BT.601 conversion for standard-definition
+    television. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+    It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
+    In OpenCV, it implements a JPEG conversion. See more details in
+    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+    Args:
+        img (ndarray): The input image. It accepts:
+            1. np.uint8 type with range [0, 255];
+            2. np.float32 type with range [0, 1].
+
+    Returns:
+        ndarray: The converted BGR image. The output image has the same type
+            and range as input image.
+    """
+    img_type = img.dtype
+    img = _convert_input_type_range(img) * 255
+    out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621],
+                              [0.00791071, -0.00153632, 0],
+                              [0, -0.00318811, 0.00625893]]) * 255.0 + [
+                                  -276.836, 135.576, -222.921
+                              ]
+    out_img = _convert_output_type_range(out_img, img_type)
+    return out_img
+
+
+def convert_color_factory(src, dst):
+
+    code = getattr(cv2, f'COLOR_{src.upper()}2{dst.upper()}')
+
+    def convert_color(img):
+        out_img = cv2.cvtColor(img, code)
+        return out_img
+
+    convert_color.__doc__ = f"""Convert a {src.upper()} image to {dst.upper()}
+        image.
+
+    Args:
+        img (ndarray): The input image.
+
+    Returns:
+        ndarray: The converted {dst.upper()} image.
+    """
+
+    return convert_color
+
+
+bgr2rgb = convert_color_factory('bgr', 'rgb')
+
+rgb2bgr = convert_color_factory('rgb', 'bgr')
+
+bgr2hsv = convert_color_factory('bgr', 'hsv')
+
+hsv2bgr = convert_color_factory('hsv', 'bgr')
+
+bgr2hls = convert_color_factory('bgr', 'hls')
+
+hls2bgr = convert_color_factory('hls', 'bgr')
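+
+# Usage sketch (illustrative comment, not part of the upstream mmcv source):
+# assuming `img` is an HxWx3 uint8 BGR array, e.g. read with cv2.imread,
+# the converters above compose as follows:
+#
+#     rgb = bgr2rgb(img)                    # cv2-backed channel swap
+#     y = rgb2ycbcr(rgb, y_only=True)       # uint8 Y channel, roughly [16, 235]
+#     restored = ycbcr2bgr(bgr2ycbcr(img))  # approximate BT.601 round trip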
diff --git a/annotator/uniformer/mmcv/image/geometric.py b/annotator/uniformer/mmcv/image/geometric.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf97c201cb4e43796c911919d03fb26a07ed817d
--- /dev/null
+++ b/annotator/uniformer/mmcv/image/geometric.py
@@ -0,0 +1,728 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numbers
+
+import cv2
+import numpy as np
+
+from ..utils import to_2tuple
+from .io import imread_backend
+
+try:
+    from PIL import Image
+except ImportError:
+    Image = None
+
+
+def _scale_size(size, scale):
+    """Rescale a size by a ratio.
+
+    Args:
+        size (tuple[int]): (w, h).
+        scale (float | tuple(float)): Scaling factor.
+
+    Returns:
+        tuple[int]: scaled size.
+    """
+    if isinstance(scale, (float, int)):
+        scale = (scale, scale)
+    w, h = size
+    return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5)
+
+
+cv2_interp_codes = {
+    'nearest': cv2.INTER_NEAREST,
+    'bilinear': cv2.INTER_LINEAR,
+    'bicubic': cv2.INTER_CUBIC,
+    'area': cv2.INTER_AREA,
+    'lanczos': cv2.INTER_LANCZOS4
+}
+
+if Image is not None:
+    pillow_interp_codes = {
+        'nearest': Image.NEAREST,
+        'bilinear': Image.BILINEAR,
+        'bicubic': Image.BICUBIC,
+        'box': Image.BOX,
+        'lanczos': Image.LANCZOS,
+        'hamming': Image.HAMMING
+    }
+
+
+def imresize(img,
+             size,
+             return_scale=False,
+             interpolation='bilinear',
+             out=None,
+             backend=None):
+    """Resize image to a given size.
+
+    Args:
+        img (ndarray): The input image.
+        size (tuple[int]): Target size (w, h).
+        return_scale (bool): Whether to return `w_scale` and `h_scale`.
+        interpolation (str): Interpolation method, accepted values are
+            "nearest", "bilinear", "bicubic", "area", "lanczos" for the 'cv2'
+            backend and "nearest", "bilinear", "bicubic", "box", "lanczos",
+            "hamming" for the 'pillow' backend.
+        out (ndarray): The output destination.
+        backend (str | None): The image resize backend type. Options are `cv2`,
+            `pillow`, `None`. If backend is None, the global imread_backend
+            specified by ``mmcv.use_backend()`` will be used. Default: None.
+
+    Returns:
+        tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+            `resized_img`.
+    """
+    h, w = img.shape[:2]
+    if backend is None:
+        backend = imread_backend
+    if backend not in ['cv2', 'pillow']:
+        raise ValueError(f'backend: {backend} is not supported for resize. '
+                         f"Supported backends are 'cv2', 'pillow'")
+
+    if backend == 'pillow':
+        assert img.dtype == np.uint8, 'Pillow backend only supports uint8 type'
+        pil_image = Image.fromarray(img)
+        pil_image = pil_image.resize(size, pillow_interp_codes[interpolation])
+        resized_img = np.array(pil_image)
+    else:
+        resized_img = cv2.resize(
+            img, size, dst=out, interpolation=cv2_interp_codes[interpolation])
+    if not return_scale:
+        return resized_img
+    else:
+        w_scale = size[0] / w
+        h_scale = size[1] / h
+        return resized_img, w_scale, h_scale
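+
+# Usage sketch (illustrative comment, not part of the upstream mmcv source):
+# assuming `img` is an HxWxC ndarray, resizing to a fixed (w, h) and also
+# recovering the per-axis scale factors could look like:
+#
+#     small = imresize(img, (320, 240))
+#     small, w_scale, h_scale = imresize(img, (320, 240), return_scale=True)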
+
+
+def imresize_to_multiple(img,
+                         divisor,
+                         size=None,
+                         scale_factor=None,
+                         keep_ratio=False,
+                         return_scale=False,
+                         interpolation='bilinear',
+                         out=None,
+                         backend=None):
+    """Resize image according to a given size or scale factor and then rounds
+    up the the resized or rescaled image size to the nearest value that can be
+    divided by the divisor.
+
+    Args:
+        img (ndarray): The input image.
+        divisor (int | tuple): Resized image size will be a multiple of
+            divisor. If divisor is a tuple, divisor should be
+            (w_divisor, h_divisor).
+        size (None | int | tuple[int]): Target size (w, h). Default: None.
+        scale_factor (None | float | tuple[float]): Multiplier for spatial
+            size. Should match input size if it is a tuple and the 2D style is
+            (w_scale_factor, h_scale_factor). Default: None.
+        keep_ratio (bool): Whether to keep the aspect ratio when resizing the
+            image. Default: False.
+        return_scale (bool): Whether to return `w_scale` and `h_scale`.
+        interpolation (str): Interpolation method, accepted values are
+            "nearest", "bilinear", "bicubic", "area", "lanczos" for the 'cv2'
+            backend and "nearest", "bilinear", "bicubic", "box", "lanczos",
+            "hamming" for the 'pillow' backend.
+        out (ndarray): The output destination.
+        backend (str | None): The image resize backend type. Options are `cv2`,
+            `pillow`, `None`. If backend is None, the global imread_backend
+            specified by ``mmcv.use_backend()`` will be used. Default: None.
+
+    Returns:
+        tuple | ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+            `resized_img`.
+    """
+    h, w = img.shape[:2]
+    if size is not None and scale_factor is not None:
+        raise ValueError('only one of size or scale_factor should be defined')
+    elif size is None and scale_factor is None:
+        raise ValueError('one of size or scale_factor should be defined')
+    elif size is not None:
+        size = to_2tuple(size)
+        if keep_ratio:
+            size = rescale_size((w, h), size, return_scale=False)
+    else:
+        size = _scale_size((w, h), scale_factor)
+
+    divisor = to_2tuple(divisor)
+    size = tuple([int(np.ceil(s / d)) * d for s, d in zip(size, divisor)])
+    resized_img, w_scale, h_scale = imresize(
+        img,
+        size,
+        return_scale=True,
+        interpolation=interpolation,
+        out=out,
+        backend=backend)
+    if return_scale:
+        return resized_img, w_scale, h_scale
+    else:
+        return resized_img
+
+
+def imresize_like(img,
+                  dst_img,
+                  return_scale=False,
+                  interpolation='bilinear',
+                  backend=None):
+    """Resize image to the same size of a given image.
+
+    Args:
+        img (ndarray): The input image.
+        dst_img (ndarray): The target image.
+        return_scale (bool): Whether to return `w_scale` and `h_scale`.
+        interpolation (str): Same as :func:`resize`.
+        backend (str | None): Same as :func:`resize`.
+
+    Returns:
+        tuple or ndarray: (`resized_img`, `w_scale`, `h_scale`) or
+            `resized_img`.
+    """
+    h, w = dst_img.shape[:2]
+    return imresize(img, (w, h), return_scale, interpolation, backend=backend)
+
+
+def rescale_size(old_size, scale, return_scale=False):
+    """Calculate the new size to be rescaled to.
+
+    Args:
+        old_size (tuple[int]): The old size (w, h) of image.
+        scale (float | tuple[int]): The scaling factor or maximum size.
+            If it is a float number, then the image will be rescaled by this
+            factor, else if it is a tuple of 2 integers, then the image will
+            be rescaled as large as possible within the scale.
+        return_scale (bool): Whether to return the scaling factor besides the
+            rescaled image size.
+
+    Returns:
+        tuple[int]: The new rescaled image size.
+    """
+    w, h = old_size
+    if isinstance(scale, (float, int)):
+        if scale <= 0:
+            raise ValueError(f'Invalid scale {scale}, must be positive.')
+        scale_factor = scale
+    elif isinstance(scale, tuple):
+        max_long_edge = max(scale)
+        max_short_edge = min(scale)
+        scale_factor = min(max_long_edge / max(h, w),
+                           max_short_edge / min(h, w))
+    else:
+        raise TypeError(
+            f'Scale must be a number or tuple of int, but got {type(scale)}')
+
+    new_size = _scale_size((w, h), scale_factor)
+
+    if return_scale:
+        return new_size, scale_factor
+    else:
+        return new_size
+
+
+def imrescale(img,
+              scale,
+              return_scale=False,
+              interpolation='bilinear',
+              backend=None):
+    """Resize image while keeping the aspect ratio.
+
+    Args:
+        img (ndarray): The input image.
+        scale (float | tuple[int]): The scaling factor or maximum size.
+            If it is a float number, then the image will be rescaled by this
+            factor, else if it is a tuple of 2 integers, then the image will
+            be rescaled as large as possible within the scale.
+        return_scale (bool): Whether to return the scaling factor besides the
+            rescaled image.
+        interpolation (str): Same as :func:`resize`.
+        backend (str | None): Same as :func:`resize`.
+
+    Returns:
+        ndarray: The rescaled image.
+    """
+    h, w = img.shape[:2]
+    new_size, scale_factor = rescale_size((w, h), scale, return_scale=True)
+    rescaled_img = imresize(
+        img, new_size, interpolation=interpolation, backend=backend)
+    if return_scale:
+        return rescaled_img, scale_factor
+    else:
+        return rescaled_img
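+
+# Usage sketch (illustrative comment, not part of the upstream mmcv source):
+# imrescale keeps the aspect ratio, so a tuple scale acts as a max-size bound:
+#
+#     half = imrescale(img, 0.5)                       # both edges halved
+#     bounded, s = imrescale(img, (1333, 800), return_scale=True)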
+
+
+def imflip(img, direction='horizontal'):
+    """Flip an image horizontally or vertically.
+
+    Args:
+        img (ndarray): Image to be flipped.
+        direction (str): The flip direction, either "horizontal" or
+            "vertical" or "diagonal".
+
+    Returns:
+        ndarray: The flipped image.
+    """
+    assert direction in ['horizontal', 'vertical', 'diagonal']
+    if direction == 'horizontal':
+        return np.flip(img, axis=1)
+    elif direction == 'vertical':
+        return np.flip(img, axis=0)
+    else:
+        return np.flip(img, axis=(0, 1))
+
+
+def imflip_(img, direction='horizontal'):
+    """Inplace flip an image horizontally or vertically.
+
+    Args:
+        img (ndarray): Image to be flipped.
+        direction (str): The flip direction, either "horizontal" or
+            "vertical" or "diagonal".
+
+    Returns:
+        ndarray: The flipped image (inplace).
+    """
+    assert direction in ['horizontal', 'vertical', 'diagonal']
+    if direction == 'horizontal':
+        return cv2.flip(img, 1, img)
+    elif direction == 'vertical':
+        return cv2.flip(img, 0, img)
+    else:
+        return cv2.flip(img, -1, img)
+
+
+def imrotate(img,
+             angle,
+             center=None,
+             scale=1.0,
+             border_value=0,
+             interpolation='bilinear',
+             auto_bound=False):
+    """Rotate an image.
+
+    Args:
+        img (ndarray): Image to be rotated.
+        angle (float): Rotation angle in degrees, positive values mean
+            clockwise rotation.
+        center (tuple[float], optional): Center point (w, h) of the rotation in
+            the source image. If not specified, the center of the image will be
+            used.
+        scale (float): Isotropic scale factor.
+        border_value (int): Border value.
+        interpolation (str): Same as :func:`resize`.
+        auto_bound (bool): Whether to adjust the image size to cover the whole
+            rotated image.
+
+    Returns:
+        ndarray: The rotated image.
+    """
+    if center is not None and auto_bound:
+        raise ValueError('`auto_bound` conflicts with `center`')
+    h, w = img.shape[:2]
+    if center is None:
+        center = ((w - 1) * 0.5, (h - 1) * 0.5)
+    assert isinstance(center, tuple)
+
+    matrix = cv2.getRotationMatrix2D(center, -angle, scale)
+    if auto_bound:
+        cos = np.abs(matrix[0, 0])
+        sin = np.abs(matrix[0, 1])
+        new_w = h * sin + w * cos
+        new_h = h * cos + w * sin
+        matrix[0, 2] += (new_w - w) * 0.5
+        matrix[1, 2] += (new_h - h) * 0.5
+        w = int(np.round(new_w))
+        h = int(np.round(new_h))
+    rotated = cv2.warpAffine(
+        img,
+        matrix, (w, h),
+        flags=cv2_interp_codes[interpolation],
+        borderValue=border_value)
+    return rotated
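+
+# Usage sketch (illustrative comment, not part of the upstream mmcv source):
+# rotate 30 degrees clockwise about the image center, letting the canvas grow
+# so nothing is cropped:
+#
+#     rotated = imrotate(img, 30, auto_bound=True, border_value=255)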
+
+
+def bbox_clip(bboxes, img_shape):
+    """Clip bboxes to fit the image shape.
+
+    Args:
+        bboxes (ndarray): Shape (..., 4*k)
+        img_shape (tuple[int]): (height, width) of the image.
+
+    Returns:
+        ndarray: Clipped bboxes.
+    """
+    assert bboxes.shape[-1] % 4 == 0
+    cmin = np.empty(bboxes.shape[-1], dtype=bboxes.dtype)
+    cmin[0::2] = img_shape[1] - 1
+    cmin[1::2] = img_shape[0] - 1
+    clipped_bboxes = np.maximum(np.minimum(bboxes, cmin), 0)
+    return clipped_bboxes
+
+
+def bbox_scaling(bboxes, scale, clip_shape=None):
+    """Scaling bboxes w.r.t the box center.
+
+    Args:
+        bboxes (ndarray): Shape(..., 4).
+        scale (float): Scaling factor.
+        clip_shape (tuple[int], optional): If specified, bboxes that exceed the
+            boundary will be clipped according to the given shape (h, w).
+
+    Returns:
+        ndarray: Scaled bboxes.
+    """
+    if float(scale) == 1.0:
+        scaled_bboxes = bboxes.copy()
+    else:
+        w = bboxes[..., 2] - bboxes[..., 0] + 1
+        h = bboxes[..., 3] - bboxes[..., 1] + 1
+        dw = (w * (scale - 1)) * 0.5
+        dh = (h * (scale - 1)) * 0.5
+        scaled_bboxes = bboxes + np.stack((-dw, -dh, dw, dh), axis=-1)
+    if clip_shape is not None:
+        return bbox_clip(scaled_bboxes, clip_shape)
+    else:
+        return scaled_bboxes
+
+
+def imcrop(img, bboxes, scale=1.0, pad_fill=None):
+    """Crop image patches.
+
+    3 steps: scale the bboxes -> clip bboxes -> crop and pad.
+
+    Args:
+        img (ndarray): Image to be cropped.
+        bboxes (ndarray): Shape (k, 4) or (4, ), location of cropped bboxes.
+        scale (float, optional): Scale ratio of bboxes, the default value
+            1.0 means no scaling.
+        pad_fill (Number | list[Number]): Value to be filled for padding.
+            Default: None, which means no padding.
+
+    Returns:
+        list[ndarray] | ndarray: The cropped image patches.
+    """
+    chn = 1 if img.ndim == 2 else img.shape[2]
+    if pad_fill is not None:
+        if isinstance(pad_fill, (int, float)):
+            pad_fill = [pad_fill for _ in range(chn)]
+        assert len(pad_fill) == chn
+
+    _bboxes = bboxes[None, ...] if bboxes.ndim == 1 else bboxes
+    scaled_bboxes = bbox_scaling(_bboxes, scale).astype(np.int32)
+    clipped_bbox = bbox_clip(scaled_bboxes, img.shape)
+
+    patches = []
+    for i in range(clipped_bbox.shape[0]):
+        x1, y1, x2, y2 = tuple(clipped_bbox[i, :])
+        if pad_fill is None:
+            patch = img[y1:y2 + 1, x1:x2 + 1, ...]
+        else:
+            _x1, _y1, _x2, _y2 = tuple(scaled_bboxes[i, :])
+            if chn == 1:
+                patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1)
+            else:
+                patch_shape = (_y2 - _y1 + 1, _x2 - _x1 + 1, chn)
+            patch = np.array(
+                pad_fill, dtype=img.dtype) * np.ones(
+                    patch_shape, dtype=img.dtype)
+            x_start = 0 if _x1 >= 0 else -_x1
+            y_start = 0 if _y1 >= 0 else -_y1
+            w = x2 - x1 + 1
+            h = y2 - y1 + 1
+            patch[y_start:y_start + h, x_start:x_start + w,
+                  ...] = img[y1:y1 + h, x1:x1 + w, ...]
+        patches.append(patch)
+
+    if bboxes.ndim == 1:
+        return patches[0]
+    else:
+        return patches
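+
+# Usage sketch (illustrative comment, not part of the upstream mmcv source):
+# crop one box enlarged by 20%, padding with gray where it leaves the image:
+#
+#     bbox = np.array([10, 10, 99, 99])             # x1, y1, x2, y2
+#     patch = imcrop(img, bbox, scale=1.2, pad_fill=128)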
+
+
+def impad(img,
+          *,
+          shape=None,
+          padding=None,
+          pad_val=0,
+          padding_mode='constant'):
+    """Pad the given image to a certain shape or pad on all sides with
+    specified padding mode and padding value.
+
+    Args:
+        img (ndarray): Image to be padded.
+        shape (tuple[int]): Expected padding shape (h, w). Default: None.
+        padding (int or tuple[int]): Padding on each border. If a single int
+            is provided, it is used to pad all borders. If a tuple of length 2
+            is provided, it is the padding on left/right and top/bottom
+            respectively. If a tuple of length 4 is provided, it is the
+            padding for the left, top, right and bottom borders respectively.
+            Default: None. Note that `shape` and `padding` cannot both be set.
+        pad_val (Number | Sequence[Number]): Values to be filled in padding
+            areas when padding_mode is 'constant'. Default: 0.
+        padding_mode (str): Type of padding. Should be: constant, edge,
+            reflect or symmetric. Default: constant.
+
+            - constant: pads with a constant value, this value is specified
+                with pad_val.
+            - edge: pads with the last value at the edge of the image.
+            - reflect: pads with reflection of image without repeating the
+                last value on the edge. For example, padding [1, 2, 3, 4]
+                with 2 elements on both sides in reflect mode will result
+                in [3, 2, 1, 2, 3, 4, 3, 2].
+            - symmetric: pads with reflection of image repeating the last
+                value on the edge. For example, padding [1, 2, 3, 4] with
+                2 elements on both sides in symmetric mode will result in
+                [2, 1, 1, 2, 3, 4, 4, 3]
+
+    Returns:
+        ndarray: The padded image.
+    """
+
+    assert (shape is not None) ^ (padding is not None)
+    if shape is not None:
+        padding = (0, 0, shape[1] - img.shape[1], shape[0] - img.shape[0])
+
+    # check pad_val
+    if isinstance(pad_val, tuple):
+        assert len(pad_val) == img.shape[-1]
+    elif not isinstance(pad_val, numbers.Number):
+        raise TypeError('pad_val must be an int or a tuple. '
+                        f'But received {type(pad_val)}')
+
+    # check padding
+    if isinstance(padding, tuple) and len(padding) in [2, 4]:
+        if len(padding) == 2:
+            padding = (padding[0], padding[1], padding[0], padding[1])
+    elif isinstance(padding, numbers.Number):
+        padding = (padding, padding, padding, padding)
+    else:
+        raise ValueError('Padding must be an int or a 2- or 4-element tuple. '
+                         f'But received {padding}')
+
+    # check padding mode
+    assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
+
+    border_type = {
+        'constant': cv2.BORDER_CONSTANT,
+        'edge': cv2.BORDER_REPLICATE,
+        'reflect': cv2.BORDER_REFLECT_101,
+        'symmetric': cv2.BORDER_REFLECT
+    }
+    img = cv2.copyMakeBorder(
+        img,
+        padding[1],
+        padding[3],
+        padding[0],
+        padding[2],
+        border_type[padding_mode],
+        value=pad_val)
+
+    return img
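+
+# Usage sketch (illustrative comment, not part of the upstream mmcv source):
+# either pad to an explicit target shape (which must be at least the image
+# size) or give per-border padding:
+#
+#     padded = impad(img, shape=(256, 256), pad_val=0)
+#     padded = impad(img, padding=(4, 8, 4, 8), padding_mode='reflect')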
+
+
+def impad_to_multiple(img, divisor, pad_val=0):
+    """Pad an image to ensure each edge to be multiple to some number.
+
+    Args:
+        img (ndarray): Image to be padded.
+        divisor (int): Padded image edges will be a multiple of divisor.
+        pad_val (Number | Sequence[Number]): Same as :func:`impad`.
+
+    Returns:
+        ndarray: The padded image.
+    """
+    pad_h = int(np.ceil(img.shape[0] / divisor)) * divisor
+    pad_w = int(np.ceil(img.shape[1] / divisor)) * divisor
+    return impad(img, shape=(pad_h, pad_w), pad_val=pad_val)
+
+
+def cutout(img, shape, pad_val=0):
+    """Randomly cut out a rectangle from the original img.
+
+    Args:
+        img (ndarray): Image to be cutout.
+        shape (int | tuple[int]): Expected cutout shape (h, w). If given as
+            an int, the value will be used for both h and w.
+        pad_val (int | float | tuple[int | float]): Values to be filled in the
+            cut area. Defaults to 0.
+
+    Returns:
+        ndarray: The cutout image.
+    """
+
+    channels = 1 if img.ndim == 2 else img.shape[2]
+    if isinstance(shape, int):
+        cut_h, cut_w = shape, shape
+    else:
+        assert isinstance(shape, tuple) and len(shape) == 2, \
+            f'shape must be an int or a tuple with length 2, but got type ' \
+            f'{type(shape)} instead.'
+        cut_h, cut_w = shape
+    if isinstance(pad_val, (int, float)):
+        pad_val = tuple([pad_val] * channels)
+    elif isinstance(pad_val, tuple):
+        assert len(pad_val) == channels, \
+            'Expected the number of elements in tuple to equal the channels ' \
+            'of the input image. Found {} vs {}'.format(
+                len(pad_val), channels)
+    else:
+        raise TypeError(f'Invalid type {type(pad_val)} for `pad_val`')
+
+    img_h, img_w = img.shape[:2]
+    y0 = np.random.uniform(img_h)
+    x0 = np.random.uniform(img_w)
+
+    y1 = int(max(0, y0 - cut_h / 2.))
+    x1 = int(max(0, x0 - cut_w / 2.))
+    y2 = min(img_h, y1 + cut_h)
+    x2 = min(img_w, x1 + cut_w)
+
+    if img.ndim == 2:
+        patch_shape = (y2 - y1, x2 - x1)
+    else:
+        patch_shape = (y2 - y1, x2 - x1, channels)
+
+    img_cutout = img.copy()
+    patch = np.array(
+        pad_val, dtype=img.dtype) * np.ones(
+            patch_shape, dtype=img.dtype)
+    img_cutout[y1:y2, x1:x2, ...] = patch
+
+    return img_cutout
+
+
+def _get_shear_matrix(magnitude, direction='horizontal'):
+    """Generate the shear matrix for transformation.
+
+    Args:
+        magnitude (int | float): The magnitude used for shear.
+        direction (str): The flip direction, either "horizontal"
+            or "vertical".
+
+    Returns:
+        ndarray: The shear matrix with dtype float32.
+    """
+    if direction == 'horizontal':
+        shear_matrix = np.float32([[1, magnitude, 0], [0, 1, 0]])
+    elif direction == 'vertical':
+        shear_matrix = np.float32([[1, 0, 0], [magnitude, 1, 0]])
+    return shear_matrix
+
+
+def imshear(img,
+            magnitude,
+            direction='horizontal',
+            border_value=0,
+            interpolation='bilinear'):
+    """Shear an image.
+
+    Args:
+        img (ndarray): Image to be sheared with format (h, w)
+            or (h, w, c).
+        magnitude (int | float): The magnitude used for shear.
+        direction (str): The flip direction, either "horizontal"
+            or "vertical".
+        border_value (int | tuple[int]): Value used in case of a
+            constant border.
+        interpolation (str): Same as :func:`resize`.
+
+    Returns:
+        ndarray: The sheared image.
+    """
+    assert direction in ['horizontal',
+                         'vertical'], f'Invalid direction: {direction}'
+    height, width = img.shape[:2]
+    if img.ndim == 2:
+        channels = 1
+    elif img.ndim == 3:
+        channels = img.shape[-1]
+    if isinstance(border_value, int):
+        border_value = tuple([border_value] * channels)
+    elif isinstance(border_value, tuple):
+        assert len(border_value) == channels, \
+            'Expected the number of elements in tuple to equal the channels ' \
+            'of the input image. Found {} vs {}'.format(
+                len(border_value), channels)
+    else:
+        raise ValueError(
+            f'Invalid type {type(border_value)} for `border_value`')
+    shear_matrix = _get_shear_matrix(magnitude, direction)
+    sheared = cv2.warpAffine(
+        img,
+        shear_matrix,
+        (width, height),
+        # Note: if the number of elements in `border_value` is greater than 3
+        # (e.g. shearing masks with more than 3 channels), `cv2.warpAffine`
+        # will raise a TypeError. Here we simply slice the first 3 values in
+        # `border_value`.
+        borderValue=border_value[:3],
+        flags=cv2_interp_codes[interpolation])
+    return sheared
+
+
+def _get_translate_matrix(offset, direction='horizontal'):
+    """Generate the translate matrix.
+
+    Args:
+        offset (int | float): The offset used for translate.
+        direction (str): The translate direction, either
+            "horizontal" or "vertical".
+
+    Returns:
+        ndarray: The translate matrix with dtype float32.
+    """
+    if direction == 'horizontal':
+        translate_matrix = np.float32([[1, 0, offset], [0, 1, 0]])
+    elif direction == 'vertical':
+        translate_matrix = np.float32([[1, 0, 0], [0, 1, offset]])
+    return translate_matrix
+
+
+def imtranslate(img,
+                offset,
+                direction='horizontal',
+                border_value=0,
+                interpolation='bilinear'):
+    """Translate an image.
+
+    Args:
+        img (ndarray): Image to be translated with format
+            (h, w) or (h, w, c).
+        offset (int | float): The offset used for translate.
+        direction (str): The translate direction, either "horizontal"
+            or "vertical".
+        border_value (int | tuple[int]): Value used in case of a
+            constant border.
+        interpolation (str): Same as :func:`resize`.
+
+    Returns:
+        ndarray: The translated image.
+    """
+    assert direction in ['horizontal',
+                         'vertical'], f'Invalid direction: {direction}'
+    height, width = img.shape[:2]
+    if img.ndim == 2:
+        channels = 1
+    elif img.ndim == 3:
+        channels = img.shape[-1]
+    if isinstance(border_value, int):
+        border_value = tuple([border_value] * channels)
+    elif isinstance(border_value, tuple):
+        assert len(border_value) == channels, \
+            'Expected the number of elements in tuple to equal the channels ' \
+            'of the input image. Found {} vs {}'.format(
+                len(border_value), channels)
+    else:
+        raise ValueError(
+            f'Invalid type {type(border_value)} for `border_value`.')
+    translate_matrix = _get_translate_matrix(offset, direction)
+    translated = cv2.warpAffine(
+        img,
+        translate_matrix,
+        (width, height),
+        # Note: if the number of elements in `border_value` is greater than 3
+        # (e.g. translating masks with more than 3 channels), `cv2.warpAffine`
+        # will raise a TypeError. Here we simply slice the first 3 values in
+        # `border_value`.
+        borderValue=border_value[:3],
+        flags=cv2_interp_codes[interpolation])
+    return translated
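+
+# Usage sketch (illustrative comment, not part of the upstream mmcv source):
+# both affine helpers share the same border/interpolation handling:
+#
+#     sheared = imshear(img, 0.2, direction='horizontal', border_value=0)
+#     shifted = imtranslate(img, 10, direction='vertical', border_value=0)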
diff --git a/annotator/uniformer/mmcv/image/io.py b/annotator/uniformer/mmcv/image/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3fa2e8cc06b1a7b0b69de6406980b15d61a1e5d
--- /dev/null
+++ b/annotator/uniformer/mmcv/image/io.py
@@ -0,0 +1,258 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import io
+import os.path as osp
+from pathlib import Path
+
+import cv2
+import numpy as np
+from cv2 import (IMREAD_COLOR, IMREAD_GRAYSCALE, IMREAD_IGNORE_ORIENTATION,
+                 IMREAD_UNCHANGED)
+
+from annotator.uniformer.mmcv.utils import check_file_exist, is_str, mkdir_or_exist
+
+try:
+    from turbojpeg import TJCS_RGB, TJPF_BGR, TJPF_GRAY, TurboJPEG
+except ImportError:
+    TJCS_RGB = TJPF_GRAY = TJPF_BGR = TurboJPEG = None
+
+try:
+    from PIL import Image, ImageOps
+except ImportError:
+    Image = None
+
+try:
+    import tifffile
+except ImportError:
+    tifffile = None
+
+jpeg = None
+supported_backends = ['cv2', 'turbojpeg', 'pillow', 'tifffile']
+
+imread_flags = {
+    'color': IMREAD_COLOR,
+    'grayscale': IMREAD_GRAYSCALE,
+    'unchanged': IMREAD_UNCHANGED,
+    'color_ignore_orientation': IMREAD_IGNORE_ORIENTATION | IMREAD_COLOR,
+    'grayscale_ignore_orientation':
+    IMREAD_IGNORE_ORIENTATION | IMREAD_GRAYSCALE
+}
+
+imread_backend = 'cv2'
+
+
+def use_backend(backend):
+    """Select a backend for image decoding.
+
+    Args:
+        backend (str): The image decoding backend type. Options are `cv2`,
+        `pillow`, `turbojpeg` (see https://github.com/lilohuang/PyTurboJPEG)
+        and `tifffile`. `turbojpeg` is faster but it only supports `.jpeg`
+        file format.
+    """
+    assert backend in supported_backends
+    global imread_backend
+    imread_backend = backend
+    if imread_backend == 'turbojpeg':
+        if TurboJPEG is None:
+            raise ImportError('`PyTurboJPEG` is not installed')
+        global jpeg
+        if jpeg is None:
+            jpeg = TurboJPEG()
+    elif imread_backend == 'pillow':
+        if Image is None:
+            raise ImportError('`Pillow` is not installed')
+    elif imread_backend == 'tifffile':
+        if tifffile is None:
+            raise ImportError('`tifffile` is not installed')
+
+
+def _jpegflag(flag='color', channel_order='bgr'):
+    channel_order = channel_order.lower()
+    if channel_order not in ['rgb', 'bgr']:
+        raise ValueError('channel order must be either "rgb" or "bgr"')
+
+    if flag == 'color':
+        if channel_order == 'bgr':
+            return TJPF_BGR
+        elif channel_order == 'rgb':
+            return TJCS_RGB
+    elif flag == 'grayscale':
+        return TJPF_GRAY
+    else:
+        raise ValueError('flag must be "color" or "grayscale"')
+
+
+def _pillow2array(img, flag='color', channel_order='bgr'):
+    """Convert a pillow image to numpy array.
+
+    Args:
+        img (:obj:`PIL.Image.Image`): The image loaded using PIL.
+        flag (str): Flags specifying the color type of a loaded image,
+            candidates are 'color', 'grayscale' and 'unchanged'.
+            Defaults to 'color'.
+        channel_order (str): The channel order of the output image array,
+            candidates are 'bgr' and 'rgb'. Defaults to 'bgr'.
+
+    Returns:
+        np.ndarray: The converted numpy array
+    """
+    channel_order = channel_order.lower()
+    if channel_order not in ['rgb', 'bgr']:
+        raise ValueError('channel order must be either "rgb" or "bgr"')
+
+    if flag == 'unchanged':
+        array = np.array(img)
+        if array.ndim >= 3 and array.shape[2] >= 3:  # color image
+            array[:, :, :3] = array[:, :, (2, 1, 0)]  # RGB to BGR
+    else:
+        # Handle exif orientation tag
+        if flag in ['color', 'grayscale']:
+            img = ImageOps.exif_transpose(img)
+        # If the image mode is not 'RGB', convert it to 'RGB' first.
+        if img.mode != 'RGB':
+            if img.mode != 'LA':
+                # Most formats except 'LA' can be directly converted to RGB
+                img = img.convert('RGB')
+            else:
+                # When the mode is 'LA', the default conversion will fill in
+                #  the canvas with black, which sometimes shadows black objects
+                #  in the foreground.
+                #
+                # Therefore, a random color (124, 117, 104) is used for canvas
+                img_rgba = img.convert('RGBA')
+                img = Image.new('RGB', img_rgba.size, (124, 117, 104))
+                img.paste(img_rgba, mask=img_rgba.split()[3])  # 3 is alpha
+        if flag in ['color', 'color_ignore_orientation']:
+            array = np.array(img)
+            if channel_order != 'rgb':
+                array = array[:, :, ::-1]  # RGB to BGR
+        elif flag in ['grayscale', 'grayscale_ignore_orientation']:
+            img = img.convert('L')
+            array = np.array(img)
+        else:
+            raise ValueError(
+                'flag must be "color", "grayscale", "unchanged", '
+                f'"color_ignore_orientation" or "grayscale_ignore_orientation"'
+                f' but got {flag}')
+    return array
+
+
+def imread(img_or_path, flag='color', channel_order='bgr', backend=None):
+    """Read an image.
+
+    Args:
+        img_or_path (ndarray or str or Path): Either a numpy array or str or
+            pathlib.Path. If it is a numpy array (loaded image), then
+            it will be returned as is.
+        flag (str): Flags specifying the color type of a loaded image,
+            candidates are `color`, `grayscale`, `unchanged`,
+            `color_ignore_orientation` and `grayscale_ignore_orientation`.
+            By default, `cv2` and `pillow` backend would rotate the image
+            according to its EXIF info unless called with `unchanged` or
+            `*_ignore_orientation` flags. `turbojpeg` and `tifffile` backend
+            always ignore image's EXIF info regardless of the flag.
+            The `turbojpeg` backend only supports `color` and `grayscale`.
+        channel_order (str): Order of channel, candidates are `bgr` and `rgb`.
+        backend (str | None): The image decoding backend type. Options are
+            `cv2`, `pillow`, `turbojpeg`, `tifffile`, `None`.
+            If backend is None, the global imread_backend specified by
+            ``mmcv.use_backend()`` will be used. Default: None.
+
+    Returns:
+        ndarray: Loaded image array.
+    """
+
+    if backend is None:
+        backend = imread_backend
+    if backend not in supported_backends:
+        raise ValueError(f'backend: {backend} is not supported. Supported '
+                         "backends are 'cv2', 'turbojpeg', 'pillow', "
+                         "'tifffile'")
+    if isinstance(img_or_path, Path):
+        img_or_path = str(img_or_path)
+
+    if isinstance(img_or_path, np.ndarray):
+        return img_or_path
+    elif is_str(img_or_path):
+        check_file_exist(img_or_path,
+                         f'img file does not exist: {img_or_path}')
+        if backend == 'turbojpeg':
+            with open(img_or_path, 'rb') as in_file:
+                img = jpeg.decode(in_file.read(),
+                                  _jpegflag(flag, channel_order))
+                if img.shape[-1] == 1:
+                    img = img[:, :, 0]
+            return img
+        elif backend == 'pillow':
+            img = Image.open(img_or_path)
+            img = _pillow2array(img, flag, channel_order)
+            return img
+        elif backend == 'tifffile':
+            img = tifffile.imread(img_or_path)
+            return img
+        else:
+            flag = imread_flags[flag] if is_str(flag) else flag
+            img = cv2.imread(img_or_path, flag)
+            if flag == IMREAD_COLOR and channel_order == 'rgb':
+                cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
+            return img
+    else:
+        raise TypeError('"img" must be a numpy array or a str or '
+                        'a pathlib.Path object')
+
+
+def imfrombytes(content, flag='color', channel_order='bgr', backend=None):
+    """Read an image from bytes.
+
+    Args:
+        content (bytes): Image bytes got from files or other streams.
+        flag (str): Same as :func:`imread`.
+        channel_order (str): The channel order of the output image array,
+            candidates are 'bgr' and 'rgb'. Default: 'bgr'.
+        backend (str | None): The image decoding backend type. Options are
+            `cv2`, `pillow`, `turbojpeg`, `None`. If backend is None, the
+            global imread_backend specified by ``mmcv.use_backend()`` will be
+            used. Default: None.
+
+    Returns:
+        ndarray: Loaded image array.
+    """
+
+    if backend is None:
+        backend = imread_backend
+    if backend not in supported_backends:
+        raise ValueError(f'backend: {backend} is not supported. Supported '
+                         "backends are 'cv2', 'turbojpeg', 'pillow'")
+    if backend == 'turbojpeg':
+        img = jpeg.decode(content, _jpegflag(flag, channel_order))
+        if img.shape[-1] == 1:
+            img = img[:, :, 0]
+        return img
+    elif backend == 'pillow':
+        buff = io.BytesIO(content)
+        img = Image.open(buff)
+        img = _pillow2array(img, flag, channel_order)
+        return img
+    else:
+        img_np = np.frombuffer(content, np.uint8)
+        flag = imread_flags[flag] if is_str(flag) else flag
+        img = cv2.imdecode(img_np, flag)
+        if flag == IMREAD_COLOR and channel_order == 'rgb':
+            cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
+        return img
+
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+    """Write image to file.
+
+    Args:
+        img (ndarray): Image array to be written.
+        file_path (str): Image file path.
+        params (None or list): Same as opencv :func:`imwrite` interface.
+        auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+            whether to create it automatically.
+
+    Returns:
+        bool: Successful or not.
+    """
+    if auto_mkdir:
+        dir_name = osp.abspath(osp.dirname(file_path))
+        mkdir_or_exist(dir_name)
+    return cv2.imwrite(file_path, img, params)
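+
+# Usage sketch (illustrative comment, not part of the upstream mmcv source);
+# the file paths below are hypothetical placeholders:
+#
+#     img = imread('demo.jpg', flag='color', channel_order='bgr')
+#     with open('demo.jpg', 'rb') as f:
+#         same = imfrombytes(f.read())
+#     imwrite(img, '/tmp/out/demo_copy.png')  # parent dir created if needed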
diff --git a/annotator/uniformer/mmcv/image/misc.py b/annotator/uniformer/mmcv/image/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e61f05e3b05e4c7b40de4eb6c8eb100e6da41d0
--- /dev/null
+++ b/annotator/uniformer/mmcv/image/misc.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+
+import annotator.uniformer.mmcv as mmcv
+
+try:
+    import torch
+except ImportError:
+    torch = None
+
+
+def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True):
+    """Convert tensor to 3-channel images.
+
+    Args:
+        tensor (torch.Tensor): Tensor that contains multiple images, shape (
+            N, C, H, W).
+        mean (tuple[float], optional): Mean of images. Defaults to (0, 0, 0).
+        std (tuple[float], optional): Standard deviation of images.
+            Defaults to (1, 1, 1).
+        to_rgb (bool, optional): Whether the tensor was converted to RGB
+            format in the first place. If so, convert it back to BGR.
+            Defaults to True.
+
+    Returns:
+        list[np.ndarray]: A list that contains multiple images.
+    """
+
+    if torch is None:
+        raise RuntimeError('pytorch is not installed')
+    assert torch.is_tensor(tensor) and tensor.ndim == 4
+    assert len(mean) == 3
+    assert len(std) == 3
+
+    num_imgs = tensor.size(0)
+    mean = np.array(mean, dtype=np.float32)
+    std = np.array(std, dtype=np.float32)
+    imgs = []
+    for img_id in range(num_imgs):
+        img = tensor[img_id, ...].cpu().numpy().transpose(1, 2, 0)
+        img = mmcv.imdenormalize(
+            img, mean, std, to_bgr=to_rgb).astype(np.uint8)
+        imgs.append(np.ascontiguousarray(img))
+    return imgs
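+
+# Usage sketch (illustrative comment, not part of the upstream mmcv source),
+# assuming torch is installed; the mean/std values are only an example:
+#
+#     batch = torch.randn(2, 3, 64, 64)          # pretend normalized images
+#     imgs = tensor2imgs(batch, mean=(123.675, 116.28, 103.53),
+#                        std=(58.395, 57.12, 57.375), to_rgb=True)
+#     assert len(imgs) == 2 and imgs[0].shape == (64, 64, 3)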
diff --git a/annotator/uniformer/mmcv/image/photometric.py b/annotator/uniformer/mmcv/image/photometric.py
new file mode 100644
index 0000000000000000000000000000000000000000..5085d012019c0cbf56f66f421a378278c1a058ae
--- /dev/null
+++ b/annotator/uniformer/mmcv/image/photometric.py
@@ -0,0 +1,428 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import cv2
+import numpy as np
+
+from ..utils import is_tuple_of
+from .colorspace import bgr2gray, gray2bgr
+
+
+def imnormalize(img, mean, std, to_rgb=True):
+    """Normalize an image with mean and std.
+
+    Args:
+        img (ndarray): Image to be normalized.
+        mean (ndarray): The mean to be used for normalization.
+        std (ndarray): The std to be used for normalization.
+        to_rgb (bool): Whether to convert to rgb.
+
+    Returns:
+        ndarray: The normalized image.
+    """
+    img = img.copy().astype(np.float32)
+    return imnormalize_(img, mean, std, to_rgb)
+
+
+def imnormalize_(img, mean, std, to_rgb=True):
+    """Inplace normalize an image with mean and std.
+
+    Args:
+        img (ndarray): Image to be normalized.
+        mean (ndarray): The mean to be used for normalization.
+        std (ndarray): The std to be used for normalization.
+        to_rgb (bool): Whether to convert to rgb.
+
+    Returns:
+        ndarray: The normalized image.
+    """
+    # cv2 inplace normalization does not accept uint8
+    assert img.dtype != np.uint8
+    mean = np.float64(mean.reshape(1, -1))
+    stdinv = 1 / np.float64(std.reshape(1, -1))
+    if to_rgb:
+        cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)  # inplace
+    cv2.subtract(img, mean, img)  # inplace
+    cv2.multiply(img, stdinv, img)  # inplace
+    return img
+
+
+def imdenormalize(img, mean, std, to_bgr=True):
+    assert img.dtype != np.uint8
+    mean = mean.reshape(1, -1).astype(np.float64)
+    std = std.reshape(1, -1).astype(np.float64)
+    img = cv2.multiply(img, std)  # make a copy
+    cv2.add(img, mean, img)  # inplace
+    if to_bgr:
+        cv2.cvtColor(img, cv2.COLOR_RGB2BGR, img)  # inplace
+    return img
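+
+# Usage sketch (illustrative comment, not part of the upstream mmcv source):
+# normalization expects non-uint8 input and is undone by imdenormalize;
+# the mean/std values are only an example:
+#
+#     mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
+#     std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
+#     norm = imnormalize(img.astype(np.float32), mean, std, to_rgb=True)
+#     back = imdenormalize(norm, mean, std, to_bgr=True)  # ~original values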
+
+
+def iminvert(img):
+    """Invert (negate) an image.
+
+    Args:
+        img (ndarray): Image to be inverted.
+
+    Returns:
+        ndarray: The inverted image.
+    """
+    return np.full_like(img, 255) - img
+
+
+def solarize(img, thr=128):
+    """Solarize an image (invert all pixel values above a threshold)
+
+    Args:
+        img (ndarray): Image to be solarized.
+        thr (int): Threshold for solarizing (0 - 255).
+
+    Returns:
+        ndarray: The solarized image.
+    """
+    img = np.where(img < thr, img, 255 - img)
+    return img
+
+
+def posterize(img, bits):
+    """Posterize an image (reduce the number of bits for each color channel)
+
+    Args:
+        img (ndarray): Image to be posterized.
+        bits (int): Number of bits (1 to 8) to use for posterizing.
+
+    Returns:
+        ndarray: The posterized image.
+    """
+    shift = 8 - bits
+    img = np.left_shift(np.right_shift(img, shift), shift)
+    return img
+
+
+def adjust_color(img, alpha=1, beta=None, gamma=0):
+    r"""It blends the source image and its gray image:
+
+    .. math::
+        output = img * alpha + gray\_img * beta + gamma
+
+    Args:
+        img (ndarray): The input source image.
+        alpha (int | float): Weight for the source image. Default 1.
+        beta (int | float): Weight for the converted gray image.
+            If None, it's assigned the value (1 - `alpha`).
+        gamma (int | float): Scalar added to each sum.
+            Same as :func:`cv2.addWeighted`. Default 0.
+
+    Returns:
+        ndarray: Colored image which has the same size and dtype as input.
+    """
+    gray_img = bgr2gray(img)
+    gray_img = np.tile(gray_img[..., None], [1, 1, 3])
+    if beta is None:
+        beta = 1 - alpha
+    colored_img = cv2.addWeighted(img, alpha, gray_img, beta, gamma)
+    if not colored_img.dtype == np.uint8:
+        # Note when the dtype of `img` is not the default `np.uint8`
+        # (e.g. np.float32), the value in `colored_img` got from cv2
+        # is not guaranteed to be in range [0, 255], so here clip
+        # is needed.
+        colored_img = np.clip(colored_img, 0, 255)
+    return colored_img
+
+
+def imequalize(img):
+    """Equalize the image histogram.
+
+    This function applies a non-linear mapping to the input image,
+    in order to create a uniform distribution of grayscale values
+    in the output image.
+
+    Args:
+        img (ndarray): Image to be equalized.
+
+    Returns:
+        ndarray: The equalized image.
+    """
+
+    def _scale_channel(im, c):
+        """Scale the data in the corresponding channel."""
+        im = im[:, :, c]
+        # Compute the histogram of the image channel.
+        histo = np.histogram(im, 256, (0, 255))[0]
+        # For computing the step, filter out the nonzeros.
+        nonzero_histo = histo[histo > 0]
+        step = (np.sum(nonzero_histo) - nonzero_histo[-1]) // 255
+        if not step:
+            lut = np.array(range(256))
+        else:
+            # Compute the cumulative sum, shifted by step // 2
+            # and then normalized by step.
+            lut = (np.cumsum(histo) + (step // 2)) // step
+            # Shift lut, prepending with 0.
+            lut = np.concatenate([[0], lut[:-1]], 0)
+            # handle potential integer overflow
+            lut[lut > 255] = 255
+        # If step is zero, return the original image.
+        # Otherwise, index from lut.
+        return np.where(np.equal(step, 0), im, lut[im])
+
+    # Scales each channel independently and then stacks
+    # the result.
+    s1 = _scale_channel(img, 0)
+    s2 = _scale_channel(img, 1)
+    s3 = _scale_channel(img, 2)
+    equalized_img = np.stack([s1, s2, s3], axis=-1)
+    return equalized_img.astype(img.dtype)
+
+
+def adjust_brightness(img, factor=1.):
+    """Adjust image brightness.
+
+    This function controls the brightness of an image. An
+    enhancement factor of 0.0 gives a black image.
+    A factor of 1.0 gives the original image. This function
+    blends the source image and the degenerated black image:
+
+    .. math::
+        output = img * factor + degenerated * (1 - factor)
+
+    Args:
+        img (ndarray): Image to be brightened.
+        factor (float): A value controls the enhancement.
+            Factor 1.0 returns the original image, lower
+            factors mean less color (brightness, contrast,
+            etc), and higher values more. Default 1.
+
+    Returns:
+        ndarray: The brightened image.
+    """
+    degenerated = np.zeros_like(img)
+    # Note manually convert the dtype to np.float32, to
+    # achieve as close results as PIL.ImageEnhance.Brightness.
+    # Set beta=1-factor, and gamma=0
+    brightened_img = cv2.addWeighted(
+        img.astype(np.float32), factor, degenerated.astype(np.float32),
+        1 - factor, 0)
+    brightened_img = np.clip(brightened_img, 0, 255)
+    return brightened_img.astype(img.dtype)
+
+
+def adjust_contrast(img, factor=1.):
+    """Adjust image contrast.
+
+    This function controls the contrast of an image. An
+    enhancement factor of 0.0 gives a solid grey
+    image. A factor of 1.0 gives the original image. It
+    blends the source image and the degenerated mean image:
+
+    .. math::
+        output = img * factor + degenerated * (1 - factor)
+
+    Args:
+        img (ndarray): Image to be contrasted. BGR order.
+        factor (float): Same as :func:`mmcv.adjust_brightness`.
+
+    Returns:
+        ndarray: The contrasted image.
+    """
+    gray_img = bgr2gray(img)
+    hist = np.histogram(gray_img, 256, (0, 255))[0]
+    mean = round(np.sum(gray_img) / np.sum(hist))
+    degenerated = (np.ones_like(img[..., 0]) * mean).astype(img.dtype)
+    degenerated = gray2bgr(degenerated)
+    contrasted_img = cv2.addWeighted(
+        img.astype(np.float32), factor, degenerated.astype(np.float32),
+        1 - factor, 0)
+    contrasted_img = np.clip(contrasted_img, 0, 255)
+    return contrasted_img.astype(img.dtype)
+
+
+def auto_contrast(img, cutoff=0):
+    """Auto adjust image contrast.
+
+    This function maximizes (normalizes) image contrast by first removing the
+    cutoff percent of the lightest and darkest pixels from the histogram and
+    remapping the image so that the darkest pixel becomes black (0) and the
+    lightest becomes white (255).
+
+    Args:
+        img (ndarray): Image to be contrasted. BGR order.
+        cutoff (int | float | tuple): The cutoff percent of the lightest and
+            darkest pixels to be removed. If given as tuple, it shall be
+            (low, high). Otherwise, the single value will be used for both.
+            Defaults to 0.
+
+    Returns:
+        ndarray: The contrasted image.
+    """
+
+    def _auto_contrast_channel(im, c, cutoff):
+        im = im[:, :, c]
+        # Compute the histogram of the image channel.
+        histo = np.histogram(im, 256, (0, 255))[0]
+        # Remove cut-off percent pixels from histo
+        histo_sum = np.cumsum(histo)
+        cut_low = histo_sum[-1] * cutoff[0] // 100
+        cut_high = histo_sum[-1] - histo_sum[-1] * cutoff[1] // 100
+        histo_sum = np.clip(histo_sum, cut_low, cut_high) - cut_low
+        histo = np.concatenate([[histo_sum[0]], np.diff(histo_sum)], 0)
+
+        # Compute mapping
+        low, high = np.nonzero(histo)[0][0], np.nonzero(histo)[0][-1]
+        # If all the values have been cut off, return the origin img
+        if low >= high:
+            return im
+        scale = 255.0 / (high - low)
+        offset = -low * scale
+        lut = np.array(range(256))
+        lut = lut * scale + offset
+        lut = np.clip(lut, 0, 255)
+        return lut[im]
+
+    if isinstance(cutoff, (int, float)):
+        cutoff = (cutoff, cutoff)
+    else:
+        assert isinstance(cutoff, tuple), 'cutoff must be of type int, ' \
+            f'float or tuple, but got {type(cutoff)} instead.'
+    # Auto adjusts contrast for each channel independently and then stacks
+    # the result.
+    s1 = _auto_contrast_channel(img, 0, cutoff)
+    s2 = _auto_contrast_channel(img, 1, cutoff)
+    s3 = _auto_contrast_channel(img, 2, cutoff)
+    contrasted_img = np.stack([s1, s2, s3], axis=-1)
+    return contrasted_img.astype(img.dtype)
+
+
+def adjust_sharpness(img, factor=1., kernel=None):
+    """Adjust image sharpness.
+
+    This function controls the sharpness of an image. An
+    enhancement factor of 0.0 gives a blurred image. A
+    factor of 1.0 gives the original image. And a factor
+    of 2.0 gives a sharpened image. It blends the source
+    image and the degenerated mean image:
+
+    .. math::
+        output = img * factor + degenerated * (1 - factor)
+
+    Args:
+        img (ndarray): Image to be sharpened. BGR order.
+        factor (float): Same as :func:`mmcv.adjust_brightness`.
+        kernel (np.ndarray, optional): Filter kernel to be applied on the img
+            to obtain the degenerated img. Defaults to None.
+
+    Note:
+        No sanity check is enforced on the kernel set by users. With an
+        inappropriate kernel, ``adjust_sharpness`` may fail to perform the
+        operation its name indicates and instead apply whatever transform the
+        kernel determines.
+
+    Returns:
+        ndarray: The sharpened image.
+    """
+
+    if kernel is None:
+        # adopted from PIL.ImageFilter.SMOOTH
+        kernel = np.array([[1., 1., 1.], [1., 5., 1.], [1., 1., 1.]]) / 13
+    assert isinstance(kernel, np.ndarray), \
+        f'kernel must be of type np.ndarray, but got {type(kernel)} instead.'
+    assert kernel.ndim == 2, \
+        f'kernel must have a dimension of 2, but got {kernel.ndim} instead.'
+
+    degenerated = cv2.filter2D(img, -1, kernel)
+    sharpened_img = cv2.addWeighted(
+        img.astype(np.float32), factor, degenerated.astype(np.float32),
+        1 - factor, 0)
+    sharpened_img = np.clip(sharpened_img, 0, 255)
+    return sharpened_img.astype(img.dtype)
+
+
+def adjust_lighting(img, eigval, eigvec, alphastd=0.1, to_rgb=True):
+    """AlexNet-style PCA jitter.
+
+    This data augmentation is proposed in `ImageNet Classification with Deep
+    Convolutional Neural Networks
+    <https://dl.acm.org/doi/pdf/10.1145/3065386>`_.
+
+    Args:
+        img (ndarray): Image to be adjusted lighting. BGR order.
+        eigval (ndarray): The eigenvalues of the covariance matrix of pixel
+            values.
+        eigvec (ndarray): The eigenvectors of the covariance matrix of pixel
+            values.
+        alphastd (float): The standard deviation for distribution of alpha.
+            Defaults to 0.1
+        to_rgb (bool): Whether to convert img to rgb.
+
+    Returns:
+        ndarray: The adjusted image.
+    """
+    assert isinstance(eigval, np.ndarray) and isinstance(eigvec, np.ndarray), \
+        f'eigval and eigvec should both be of type np.ndarray, got ' \
+        f'{type(eigval)} and {type(eigvec)} instead.'
+
+    assert eigval.ndim == 1 and eigvec.ndim == 2
+    assert eigvec.shape == (3, eigval.shape[0])
+    n_eigval = eigval.shape[0]
+    assert isinstance(alphastd, float), 'alphastd should be of type float, ' \
+        f'got {type(alphastd)} instead.'
+
+    img = img.copy().astype(np.float32)
+    if to_rgb:
+        cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)  # inplace
+
+    alpha = np.random.normal(0, alphastd, n_eigval)
+    alter = eigvec \
+        * np.broadcast_to(alpha.reshape(1, n_eigval), (3, n_eigval)) \
+        * np.broadcast_to(eigval.reshape(1, n_eigval), (3, n_eigval))
+    alter = np.broadcast_to(alter.sum(axis=1).reshape(1, 1, 3), img.shape)
+    img_adjusted = img + alter
+    return img_adjusted
+
+
+def lut_transform(img, lut_table):
+    """Transform array by look-up table.
+
+    The function lut_transform fills the output array with values from the
+    look-up table. Indices of the entries are taken from the input array.
+
+    Args:
+        img (ndarray): Image to be transformed.
+        lut_table (ndarray): look-up table of 256 elements; in case of
+            multi-channel input array, the table should either have a single
+            channel (in this case the same table is used for all channels) or
+            the same number of channels as in the input array.
+
+    Returns:
+        ndarray: The transformed image.
+    """
+    assert isinstance(img, np.ndarray)
+    assert 0 <= np.min(img) and np.max(img) <= 255
+    assert isinstance(lut_table, np.ndarray)
+    assert lut_table.shape == (256, )
+
+    return cv2.LUT(np.array(img, dtype=np.uint8), lut_table)
+
+
+def clahe(img, clip_limit=40.0, tile_grid_size=(8, 8)):
+    """Use CLAHE method to process the image.
+
+    See `ZUIDERVELD,K. Contrast Limited Adaptive Histogram Equalization[J].
+    Graphics Gems, 1994:474-485.` for more information.
+
+    Args:
+        img (ndarray): Image to be processed.
+        clip_limit (float): Threshold for contrast limiting. Default: 40.0.
+        tile_grid_size (tuple[int]): Size of grid for histogram equalization.
+            Input image will be divided into equally sized rectangular tiles.
+            It defines the number of tiles in row and column. Default: (8, 8).
+
+    Returns:
+        ndarray: The processed image.
+    """
+    assert isinstance(img, np.ndarray)
+    assert img.ndim == 2
+    assert isinstance(clip_limit, (float, int))
+    assert is_tuple_of(tile_grid_size, int)
+    assert len(tile_grid_size) == 2
+
+    clahe = cv2.createCLAHE(clip_limit, tile_grid_size)
+    return clahe.apply(np.array(img, dtype=np.uint8))
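The following is a minimal usage sketch for the photometric helpers above (``lut_transform``, ``clahe`` and ``adjust_lighting``). The import path is an assumption based on the vendored mmcv layout, and the eigenvalue/eigenvector constants are the commonly quoted ImageNet PCA statistics rather than values taken from this patch.

# NOTE: adjust the import to wherever these photometric helpers live in your tree.
import numpy as np
from annotator.uniformer.mmcv.image.photometric import (adjust_lighting, clahe,
                                                        lut_transform)

img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)  # 8-bit BGR image

# Identity look-up table: the output should equal the input.
identity_lut = np.arange(256, dtype=np.uint8)
assert np.array_equal(lut_transform(img, identity_lut), img)

# CLAHE expects a single-channel image.
equalized = clahe(img[..., 0], clip_limit=40.0, tile_grid_size=(8, 8))

# PCA jitter; eigval/eigvec are the commonly quoted ImageNet statistics,
# substitute statistics computed from your own data if they differ.
eigval = np.array([0.2175, 0.0188, 0.0045])
eigvec = np.array([[-0.5675, 0.7192, 0.4009],
                   [-0.5808, -0.0045, -0.8140],
                   [-0.5836, -0.6948, 0.4203]])
jittered = adjust_lighting(img, eigval, eigvec, alphastd=0.1)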
diff --git a/annotator/uniformer/mmcv/model_zoo/deprecated.json b/annotator/uniformer/mmcv/model_zoo/deprecated.json
new file mode 100644
index 0000000000000000000000000000000000000000..25cf6f28caecc22a77e3136fefa6b8dfc0e6cb5b
--- /dev/null
+++ b/annotator/uniformer/mmcv/model_zoo/deprecated.json
@@ -0,0 +1,6 @@
+{
+  "resnet50_caffe": "detectron/resnet50_caffe",
+  "resnet50_caffe_bgr": "detectron2/resnet50_caffe_bgr",
+  "resnet101_caffe": "detectron/resnet101_caffe",
+  "resnet101_caffe_bgr": "detectron2/resnet101_caffe_bgr"
+}
diff --git a/annotator/uniformer/mmcv/model_zoo/mmcls.json b/annotator/uniformer/mmcv/model_zoo/mmcls.json
new file mode 100644
index 0000000000000000000000000000000000000000..bdb311d9fe6d9f317290feedc9e37236c6cf6e8f
--- /dev/null
+++ b/annotator/uniformer/mmcv/model_zoo/mmcls.json
@@ -0,0 +1,31 @@
+{
+  "vgg11": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_batch256_imagenet_20210208-4271cd6c.pth",
+  "vgg13": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_batch256_imagenet_20210208-4d1d6080.pth",
+  "vgg16": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_batch256_imagenet_20210208-db26f1a5.pth",
+  "vgg19": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_batch256_imagenet_20210208-e6920e4a.pth",
+  "vgg11_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg11_bn_batch256_imagenet_20210207-f244902c.pth",
+  "vgg13_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg13_bn_batch256_imagenet_20210207-1a8b7864.pth",
+  "vgg16_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg16_bn_batch256_imagenet_20210208-7e55cd29.pth",
+  "vgg19_bn": "https://download.openmmlab.com/mmclassification/v0/vgg/vgg19_bn_batch256_imagenet_20210208-da620c4f.pth",
+  "resnet18": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_batch256_imagenet_20200708-34ab8f90.pth",
+  "resnet34": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_batch256_imagenet_20200708-32ffb4f7.pth",
+  "resnet50": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_batch256_imagenet_20200708-cfb998bf.pth",
+  "resnet101": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_batch256_imagenet_20200708-753f3608.pth",
+  "resnet152": "https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_batch256_imagenet_20200708-ec25b1f9.pth",
+  "resnet50_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d50_batch256_imagenet_20200708-1ad0ce94.pth",
+  "resnet101_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d101_batch256_imagenet_20200708-9cb302ef.pth",
+  "resnet152_v1d": "https://download.openmmlab.com/mmclassification/v0/resnet/resnetv1d152_batch256_imagenet_20200708-e79cb6a2.pth",
+  "resnext50_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext50_32x4d_b32x8_imagenet_20210429-56066e27.pth",
+  "resnext101_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x4d_b32x8_imagenet_20210506-e0fa3dd5.pth",
+  "resnext101_32x8d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext101_32x8d_b32x8_imagenet_20210506-23a247d5.pth",
+  "resnext152_32x4d": "https://download.openmmlab.com/mmclassification/v0/resnext/resnext152_32x4d_b32x8_imagenet_20210524-927787be.pth",
+  "se-resnet50": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet50_batch256_imagenet_20200804-ae206104.pth",
+  "se-resnet101": "https://download.openmmlab.com/mmclassification/v0/se-resnet/se-resnet101_batch256_imagenet_20200804-ba5b51d4.pth",
+  "resnest50": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest50_imagenet_converted-1ebf0afe.pth",
+  "resnest101": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest101_imagenet_converted-032caa52.pth",
+  "resnest200": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest200_imagenet_converted-581a60f2.pth",
+  "resnest269": "https://download.openmmlab.com/mmclassification/v0/resnest/resnest269_imagenet_converted-59930960.pth",
+  "shufflenet_v1": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v1/shufflenet_v1_batch1024_imagenet_20200804-5d6cec73.pth",
+  "shufflenet_v2": "https://download.openmmlab.com/mmclassification/v0/shufflenet_v2/shufflenet_v2_batch1024_imagenet_20200812-5bf4721e.pth",
+  "mobilenet_v2": "https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth"
+}
diff --git a/annotator/uniformer/mmcv/model_zoo/open_mmlab.json b/annotator/uniformer/mmcv/model_zoo/open_mmlab.json
new file mode 100644
index 0000000000000000000000000000000000000000..8311db4feef92faa0841c697d75efbee8430c3a0
--- /dev/null
+++ b/annotator/uniformer/mmcv/model_zoo/open_mmlab.json
@@ -0,0 +1,50 @@
+{
+  "vgg16_caffe": "https://download.openmmlab.com/pretrain/third_party/vgg16_caffe-292e1171.pth",
+  "detectron/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_caffe-788b5fa3.pth",
+  "detectron2/resnet50_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet50_msra-5891d200.pth",
+  "detectron/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_caffe-3ad79236.pth",
+  "detectron2/resnet101_caffe": "https://download.openmmlab.com/pretrain/third_party/resnet101_msra-6cc46731.pth",
+  "detectron2/resnext101_32x8d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x8d-1516f1aa.pth",
+  "resnext50_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext50-32x4d-0ab1a123.pth",
+  "resnext101_32x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d-a5af3160.pth",
+  "resnext101_64x4d": "https://download.openmmlab.com/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth",
+  "contrib/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth",
+  "detectron/resnet50_gn": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn-9186a21c.pth",
+  "detectron/resnet101_gn": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn-cac0ab98.pth",
+  "jhu/resnet50_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet50_gn_ws-15beedd8.pth",
+  "jhu/resnet101_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth",
+  "jhu/resnext50_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth",
+  "jhu/resnext101_32x4d_gn_ws": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth",
+  "jhu/resnext50_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth",
+  "jhu/resnext101_32x4d_gn": "https://download.openmmlab.com/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth",
+  "msra/hrnetv2_w18_small": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18_small-b5a04e21.pth",
+  "msra/hrnetv2_w18": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w18-00eb2006.pth",
+  "msra/hrnetv2_w32": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth",
+  "msra/hrnetv2_w40": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w40-ed0b031c.pth",
+  "msra/hrnetv2_w48": "https://download.openmmlab.com/pretrain/third_party/hrnetv2_w48-d2186c55.pth",
+  "bninception_caffe": "https://download.openmmlab.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth",
+  "kin400/i3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth",
+  "kin400/nl3d_r50_f32s2_k400": "https://download.openmmlab.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth",
+  "res2net101_v1d_26w_4s": "https://download.openmmlab.com/pretrain/third_party/res2net101_v1d_26w_4s_mmdetv2-f0a600f9.pth",
+  "regnetx_400mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_400mf-a5b10d96.pth",
+  "regnetx_800mf": "https://download.openmmlab.com/pretrain/third_party/regnetx_800mf-1f4be4c7.pth",
+  "regnetx_1.6gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_1.6gf-5791c176.pth",
+  "regnetx_3.2gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_3.2gf-c2599b0f.pth",
+  "regnetx_4.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_4.0gf-a88f671e.pth",
+  "regnetx_6.4gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_6.4gf-006af45d.pth",
+  "regnetx_8.0gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_8.0gf-3c68abe7.pth",
+  "regnetx_12gf": "https://download.openmmlab.com/pretrain/third_party/regnetx_12gf-4c2a3350.pth",
+  "resnet18_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet18_v1c-b5776b93.pth",
+  "resnet50_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet50_v1c-2cccc1ad.pth",
+  "resnet101_v1c": "https://download.openmmlab.com/pretrain/third_party/resnet101_v1c-e67eebb6.pth",
+  "mmedit/vgg16": "https://download.openmmlab.com/mmediting/third_party/vgg_state_dict.pth",
+  "mmedit/res34_en_nomixup": "https://download.openmmlab.com/mmediting/third_party/model_best_resnet34_En_nomixup.pth",
+  "mmedit/mobilenet_v2": "https://download.openmmlab.com/mmediting/third_party/mobilenet_v2.pth",
+  "contrib/mobilenet_v3_large": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_large-bc2c3fd3.pth",
+  "contrib/mobilenet_v3_small": "https://download.openmmlab.com/pretrain/third_party/mobilenet_v3_small-47085aa1.pth",
+  "resnest50": "https://download.openmmlab.com/pretrain/third_party/resnest50_d2-7497a55b.pth",
+  "resnest101": "https://download.openmmlab.com/pretrain/third_party/resnest101_d2-f3b931b2.pth",
+  "resnest200": "https://download.openmmlab.com/pretrain/third_party/resnest200_d2-ca88e41f.pth",
+  "darknet53": "https://download.openmmlab.com/pretrain/third_party/darknet53-a628ea1b.pth",
+  "mmdet/mobilenet_v2": "https://download.openmmlab.com/mmdetection/v2.0/third_party/mobilenet_v2_batch256_imagenet-ff34753d.pth"
+}
diff --git a/annotator/uniformer/mmcv/ops/__init__.py b/annotator/uniformer/mmcv/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..999e090a458ee148ceca0649f1e3806a40e909bd
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/__init__.py
@@ -0,0 +1,81 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .assign_score_withk import assign_score_withk
+from .ball_query import ball_query
+from .bbox import bbox_overlaps
+from .border_align import BorderAlign, border_align
+from .box_iou_rotated import box_iou_rotated
+from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
+from .cc_attention import CrissCrossAttention
+from .contour_expand import contour_expand
+from .corner_pool import CornerPool
+from .correlation import Correlation
+from .deform_conv import DeformConv2d, DeformConv2dPack, deform_conv2d
+from .deform_roi_pool import (DeformRoIPool, DeformRoIPoolPack,
+                              ModulatedDeformRoIPoolPack, deform_roi_pool)
+from .deprecated_wrappers import Conv2d_deprecated as Conv2d
+from .deprecated_wrappers import ConvTranspose2d_deprecated as ConvTranspose2d
+from .deprecated_wrappers import Linear_deprecated as Linear
+from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d
+from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
+                         sigmoid_focal_loss, softmax_focal_loss)
+from .furthest_point_sample import (furthest_point_sample,
+                                    furthest_point_sample_with_dist)
+from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu
+from .gather_points import gather_points
+from .group_points import GroupAll, QueryAndGroup, grouping_operation
+from .info import (get_compiler_version, get_compiling_cuda_version,
+                   get_onnxruntime_op_path)
+from .iou3d import boxes_iou_bev, nms_bev, nms_normal_bev
+from .knn import knn
+from .masked_conv import MaskedConv2d, masked_conv2d
+from .modulated_deform_conv import (ModulatedDeformConv2d,
+                                    ModulatedDeformConv2dPack,
+                                    modulated_deform_conv2d)
+from .multi_scale_deform_attn import MultiScaleDeformableAttention
+from .nms import batched_nms, nms, nms_match, nms_rotated, soft_nms
+from .pixel_group import pixel_group
+from .point_sample import (SimpleRoIAlign, point_sample,
+                           rel_roi_point_to_rel_img_point)
+from .points_in_boxes import (points_in_boxes_all, points_in_boxes_cpu,
+                              points_in_boxes_part)
+from .points_sampler import PointsSampler
+from .psa_mask import PSAMask
+from .roi_align import RoIAlign, roi_align
+from .roi_align_rotated import RoIAlignRotated, roi_align_rotated
+from .roi_pool import RoIPool, roi_pool
+from .roiaware_pool3d import RoIAwarePool3d
+from .roipoint_pool3d import RoIPointPool3d
+from .saconv import SAConv2d
+from .scatter_points import DynamicScatter, dynamic_scatter
+from .sync_bn import SyncBatchNorm
+from .three_interpolate import three_interpolate
+from .three_nn import three_nn
+from .tin_shift import TINShift, tin_shift
+from .upfirdn2d import upfirdn2d
+from .voxelize import Voxelization, voxelization
+
+__all__ = [
+    'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe',
+    'carafe_naive', 'CornerPool', 'DeformConv2d', 'DeformConv2dPack',
+    'deform_conv2d', 'DeformRoIPool', 'DeformRoIPoolPack',
+    'ModulatedDeformRoIPoolPack', 'deform_roi_pool', 'SigmoidFocalLoss',
+    'SoftmaxFocalLoss', 'sigmoid_focal_loss', 'softmax_focal_loss',
+    'get_compiler_version', 'get_compiling_cuda_version',
+    'get_onnxruntime_op_path', 'MaskedConv2d', 'masked_conv2d',
+    'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack',
+    'modulated_deform_conv2d', 'batched_nms', 'nms', 'soft_nms', 'nms_match',
+    'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
+    'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
+    'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
+    'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk',
+    'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query',
+    'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
+    'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'QueryAndGroup',
+    'GroupAll', 'grouping_operation', 'contour_expand', 'three_nn',
+    'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign',
+    'border_align', 'gather_points', 'furthest_point_sample',
+    'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
+    'boxes_iou_bev', 'nms_bev', 'nms_normal_bev', 'Voxelization',
+    'voxelization', 'dynamic_scatter', 'DynamicScatter', 'RoIAwarePool3d',
+    'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all'
+]
diff --git a/annotator/uniformer/mmcv/ops/assign_score_withk.py b/annotator/uniformer/mmcv/ops/assign_score_withk.py
new file mode 100644
index 0000000000000000000000000000000000000000..4906adaa2cffd1b46912fbe7d4f87ef2f9fa0012
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/assign_score_withk.py
@@ -0,0 +1,123 @@
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+    '_ext', ['assign_score_withk_forward', 'assign_score_withk_backward'])
+
+
+class AssignScoreWithK(Function):
+    r"""Perform weighted sum to generate output features according to scores.
+    Modified from `PAConv <https://github.com/CVMI-Lab/PAConv/tree/main/
+    scene_seg/lib/paconv_lib/src/gpu>`_.
+
+    This is a memory-efficient CUDA implementation of the assign_scores
+    operation, which first transforms all point features with the weight bank,
+    then assembles neighbor features with ``knn_idx`` and performs a weighted
+    sum using ``scores``.
+
+    See appendix Sec. D of the `paper <https://arxiv.org/pdf/2103.14635.pdf>`_
+    for a more detailed description.
+
+    Note:
+        This implementation assumes the ``neighbor`` kernel input, i.e.
+        (point_features - center_features, point_features).
+        See https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/
+        pointnet2/paconv.py#L128 for more details.
+    """
+
+    @staticmethod
+    def forward(ctx,
+                scores,
+                point_features,
+                center_features,
+                knn_idx,
+                aggregate='sum'):
+        """
+        Args:
+            scores (torch.Tensor): (B, npoint, K, M), predicted scores to
+                aggregate weight matrices in the weight bank.
+                ``npoint`` is the number of sampled centers.
+                ``K`` is the number of queried neighbors.
+                ``M`` is the number of weight matrices in the weight bank.
+            point_features (torch.Tensor): (B, N, M, out_dim)
+                Pre-computed point features to be aggregated.
+            center_features (torch.Tensor): (B, N, M, out_dim)
+                Pre-computed center features to be aggregated.
+            knn_idx (torch.Tensor): (B, npoint, K), index of sampled kNN.
+                We assume the first idx in each row is the idx of the center.
+            aggregate (str, optional): Aggregation method.
+                Can be 'sum', 'avg' or 'max'. Defaults: 'sum'.
+
+        Returns:
+            torch.Tensor: (B, out_dim, npoint, K), the aggregated features.
+        """
+        agg = {'sum': 0, 'avg': 1, 'max': 2}
+
+        B, N, M, out_dim = point_features.size()
+        _, npoint, K, _ = scores.size()
+
+        output = point_features.new_zeros((B, out_dim, npoint, K))
+        ext_module.assign_score_withk_forward(
+            point_features.contiguous(),
+            center_features.contiguous(),
+            scores.contiguous(),
+            knn_idx.contiguous(),
+            output,
+            B=B,
+            N0=N,
+            N1=npoint,
+            M=M,
+            K=K,
+            O=out_dim,
+            aggregate=agg[aggregate])
+
+        ctx.save_for_backward(output, point_features, center_features, scores,
+                              knn_idx)
+        ctx.agg = agg[aggregate]
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        """
+        Args:
+            grad_out (torch.Tensor): (B, out_dim, npoint, K)
+
+        Returns:
+            grad_scores (torch.Tensor): (B, npoint, K, M)
+            grad_point_features (torch.Tensor): (B, N, M, out_dim)
+            grad_center_features (torch.Tensor): (B, N, M, out_dim)
+        """
+        _, point_features, center_features, scores, knn_idx = ctx.saved_tensors
+
+        agg = ctx.agg
+
+        B, N, M, out_dim = point_features.size()
+        _, npoint, K, _ = scores.size()
+
+        grad_point_features = point_features.new_zeros(point_features.shape)
+        grad_center_features = center_features.new_zeros(center_features.shape)
+        grad_scores = scores.new_zeros(scores.shape)
+
+        ext_module.assign_score_withk_backward(
+            grad_out.contiguous(),
+            point_features.contiguous(),
+            center_features.contiguous(),
+            scores.contiguous(),
+            knn_idx.contiguous(),
+            grad_point_features,
+            grad_center_features,
+            grad_scores,
+            B=B,
+            N0=N,
+            N1=npoint,
+            M=M,
+            K=K,
+            O=out_dim,
+            aggregate=agg)
+
+        return grad_scores, grad_point_features, \
+            grad_center_features, None, None
+
+
+assign_score_withk = AssignScoreWithK.apply
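A minimal shape sketch for the ``assign_score_withk`` op defined above. It assumes an mmcv build with the compiled ``_ext`` CUDA extension and an available GPU; the sizes are arbitrary and only illustrate the expected tensor shapes.

import torch
from annotator.uniformer.mmcv.ops import assign_score_withk

if torch.cuda.is_available():
    B, N, npoint, K, M, out_dim = 2, 64, 16, 8, 4, 32
    scores = torch.rand(B, npoint, K, M, device='cuda')
    point_features = torch.rand(B, N, M, out_dim, device='cuda')
    center_features = torch.rand(B, N, M, out_dim, device='cuda')
    # Neighbor indices into the N points; the first index of each row is
    # assumed to be the center itself (see the docstring above). Depending on
    # the extension build, knn_idx may need a specific integer dtype.
    knn_idx = torch.randint(0, N, (B, npoint, K), device='cuda')
    out = assign_score_withk(scores, point_features, center_features, knn_idx)
    assert out.shape == (B, out_dim, npoint, K)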
diff --git a/annotator/uniformer/mmcv/ops/ball_query.py b/annotator/uniformer/mmcv/ops/ball_query.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0466847c6e5c1239e359a0397568413ebc1504a
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/ball_query.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['ball_query_forward'])
+
+
+class BallQuery(Function):
+    """Find nearby points in spherical space."""
+
+    @staticmethod
+    def forward(ctx, min_radius: float, max_radius: float, sample_num: int,
+                xyz: torch.Tensor, center_xyz: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            min_radius (float): minimum radius of the balls.
+            max_radius (float): maximum radius of the balls.
+            sample_num (int): maximum number of features in the balls.
+            xyz (Tensor): (B, N, 3) xyz coordinates of the features.
+            center_xyz (Tensor): (B, npoint, 3) centers of the ball query.
+
+        Returns:
+            Tensor: (B, npoint, nsample) tensor with the indices of
+                the features that form the query balls.
+        """
+        assert center_xyz.is_contiguous()
+        assert xyz.is_contiguous()
+        assert min_radius < max_radius
+
+        B, N, _ = xyz.size()
+        npoint = center_xyz.size(1)
+        idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int)
+
+        ext_module.ball_query_forward(
+            center_xyz,
+            xyz,
+            idx,
+            b=B,
+            n=N,
+            m=npoint,
+            min_radius=min_radius,
+            max_radius=max_radius,
+            nsample=sample_num)
+        if torch.__version__ != 'parrots':
+            ctx.mark_non_differentiable(idx)
+        return idx
+
+    @staticmethod
+    def backward(ctx, a=None):
+        return None, None, None, None
+
+
+ball_query = BallQuery.apply
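A minimal usage sketch for ``ball_query``, again assuming the compiled ``_ext`` extension and a CUDA device.

import torch
from annotator.uniformer.mmcv.ops import ball_query

if torch.cuda.is_available():
    B, N, npoint, nsample = 2, 1024, 128, 16
    xyz = torch.rand(B, N, 3, device='cuda')
    center_xyz = xyz[:, :npoint, :].contiguous()  # query centers
    # Indices of up to `nsample` points within radius (0.0, 0.2] of each center.
    idx = ball_query(0.0, 0.2, nsample, xyz, center_xyz)
    assert idx.shape == (B, npoint, nsample)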
diff --git a/annotator/uniformer/mmcv/ops/bbox.py b/annotator/uniformer/mmcv/ops/bbox.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c4d58b6c91f652933974f519acd3403a833e906
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/bbox.py
@@ -0,0 +1,72 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['bbox_overlaps'])
+
+
+def bbox_overlaps(bboxes1, bboxes2, mode='iou', aligned=False, offset=0):
+    """Calculate overlap between two set of bboxes.
+
+    If ``aligned`` is ``False``, then calculate the ious between each bbox
+    of bboxes1 and bboxes2, otherwise the ious between each aligned pair of
+    bboxes1 and bboxes2.
+
+    Args:
+        bboxes1 (Tensor): shape (m, 4) in <x1, y1, x2, y2> format or empty.
+        bboxes2 (Tensor): shape (n, 4) in <x1, y1, x2, y2> format or empty.
+            If aligned is ``True``, then m and n must be equal.
+        mode (str): "iou" (intersection over union) or iof (intersection over
+            foreground).
+
+    Returns:
+        ious (Tensor): shape (m, n) if ``aligned`` is False else shape (m, )
+
+    Example:
+        >>> bboxes1 = torch.FloatTensor([
+        >>>     [0, 0, 10, 10],
+        >>>     [10, 10, 20, 20],
+        >>>     [32, 32, 38, 42],
+        >>> ])
+        >>> bboxes2 = torch.FloatTensor([
+        >>>     [0, 0, 10, 20],
+        >>>     [0, 10, 10, 19],
+        >>>     [10, 10, 20, 20],
+        >>> ])
+        >>> bbox_overlaps(bboxes1, bboxes2)
+        tensor([[0.5000, 0.0000, 0.0000],
+                [0.0000, 0.0000, 1.0000],
+                [0.0000, 0.0000, 0.0000]])
+
+    Example:
+        >>> empty = torch.FloatTensor([])
+        >>> nonempty = torch.FloatTensor([
+        >>>     [0, 0, 10, 9],
+        >>> ])
+        >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1)
+        >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0)
+        >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0)
+    """
+
+    mode_dict = {'iou': 0, 'iof': 1}
+    assert mode in mode_dict.keys()
+    mode_flag = mode_dict[mode]
+    # Either the boxes are empty or the length of boxes' last dimension is 4
+    assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0)
+    assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0)
+    assert offset == 1 or offset == 0
+
+    rows = bboxes1.size(0)
+    cols = bboxes2.size(0)
+    if aligned:
+        assert rows == cols
+
+    if rows * cols == 0:
+        return bboxes1.new(rows, 1) if aligned else bboxes1.new(rows, cols)
+
+    if aligned:
+        ious = bboxes1.new_zeros(rows)
+    else:
+        ious = bboxes1.new_zeros((rows, cols))
+    ext_module.bbox_overlaps(
+        bboxes1, bboxes2, ious, mode=mode_flag, aligned=aligned, offset=offset)
+    return ious
diff --git a/annotator/uniformer/mmcv/ops/border_align.py b/annotator/uniformer/mmcv/ops/border_align.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff305be328e9b0a15e1bbb5e6b41beb940f55c81
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/border_align.py
@@ -0,0 +1,109 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# modified from
+# https://github.com/Megvii-BaseDetection/cvpods/blob/master/cvpods/layers/border_align.py
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+    '_ext', ['border_align_forward', 'border_align_backward'])
+
+
+class BorderAlignFunction(Function):
+
+    @staticmethod
+    def symbolic(g, input, boxes, pool_size):
+        return g.op(
+            'mmcv::MMCVBorderAlign', input, boxes, pool_size_i=pool_size)
+
+    @staticmethod
+    def forward(ctx, input, boxes, pool_size):
+        ctx.pool_size = pool_size
+        ctx.input_shape = input.size()
+
+        assert boxes.ndim == 3, 'boxes must be with shape [B, H*W, 4]'
+        assert boxes.size(2) == 4, \
+            'the last dimension of boxes must be (x1, y1, x2, y2)'
+        assert input.size(1) % 4 == 0, \
+            'the channel for input feature must be divisible by factor 4'
+
+        # [B, C//4, H*W, 4]
+        output_shape = (input.size(0), input.size(1) // 4, boxes.size(1), 4)
+        output = input.new_zeros(output_shape)
+        # `argmax_idx` only used for backward
+        argmax_idx = input.new_zeros(output_shape).to(torch.int)
+
+        ext_module.border_align_forward(
+            input, boxes, output, argmax_idx, pool_size=ctx.pool_size)
+
+        ctx.save_for_backward(boxes, argmax_idx)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        boxes, argmax_idx = ctx.saved_tensors
+        grad_input = grad_output.new_zeros(ctx.input_shape)
+        # complex head architecture may cause grad_output uncontiguous
+        grad_output = grad_output.contiguous()
+        ext_module.border_align_backward(
+            grad_output,
+            boxes,
+            argmax_idx,
+            grad_input,
+            pool_size=ctx.pool_size)
+        return grad_input, None, None
+
+
+border_align = BorderAlignFunction.apply
+
+
+class BorderAlign(nn.Module):
+    r"""Border align pooling layer.
+
+    Applies border_align over the input feature based on predicted bboxes.
+    The details were described in the paper
+    `BorderDet: Border Feature for Dense Object Detection
+    <https://arxiv.org/abs/2007.11056>`_.
+
+    For each border line (e.g. top, left, bottom or right) of each box,
+    border_align does the following:
+        1. uniformly samples `pool_size`+1 positions on this line, including \
+           the start and end points.
+        2. computes the features at these positions by bilinear \
+           interpolation.
+        3. max-pools over all `pool_size`+1 positions to obtain the pooled \
+           feature.
+
+    Args:
+        pool_size (int): number of positions sampled over the boxes' borders
+            (e.g. top, bottom, left, right).
+
+    """
+
+    def __init__(self, pool_size):
+        super(BorderAlign, self).__init__()
+        self.pool_size = pool_size
+
+    def forward(self, input, boxes):
+        """
+        Args:
+            input: Features with shape [N,4C,H,W]. Channels ranged in [0,C),
+                [C,2C), [2C,3C), [3C,4C) represent the top, left, bottom,
+                right features respectively.
+            boxes: Boxes with shape [N,H*W,4]. Coordinate format (x1,y1,x2,y2).
+
+        Returns:
+            Tensor: Pooled features with shape [N,C,H*W,4]. The order is
+                (top,left,bottom,right) for the last dimension.
+        """
+        return border_align(input, boxes, self.pool_size)
+
+    def __repr__(self):
+        s = self.__class__.__name__
+        s += f'(pool_size={self.pool_size})'
+        return s
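A minimal usage sketch for ``BorderAlign``; it needs the compiled ``_ext`` CUDA ops and a GPU. The feature channel count must be a multiple of 4, and one box is given per spatial location in (x1, y1, x2, y2) format; the box values below are arbitrary but in bounds.

import torch
from annotator.uniformer.mmcv.ops import BorderAlign

if torch.cuda.is_available():
    B, C, H, W = 1, 16, 8, 8
    feats = torch.rand(B, 4 * C, H, W, device='cuda')   # [N, 4C, H, W]
    # One (arbitrary, in-bounds) box per spatial position: [N, H*W, 4].
    boxes = torch.tensor([1.0, 1.0, 5.0, 5.0], device='cuda')
    boxes = boxes.repeat(B, H * W, 1)
    pooled = BorderAlign(pool_size=10)(feats, boxes)
    assert pooled.shape == (B, C, H * W, 4)  # (top, left, bottom, right)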
diff --git a/annotator/uniformer/mmcv/ops/box_iou_rotated.py b/annotator/uniformer/mmcv/ops/box_iou_rotated.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d78015e9c2a9e7a52859b4e18f84a9aa63481a0
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/box_iou_rotated.py
@@ -0,0 +1,45 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['box_iou_rotated'])
+
+
+def box_iou_rotated(bboxes1, bboxes2, mode='iou', aligned=False):
+    """Return intersection-over-union (Jaccard index) of boxes.
+
+    Both sets of boxes are expected to be in
+    (x_center, y_center, width, height, angle) format.
+
+    If ``aligned`` is ``False``, then calculate the ious between each bbox
+    of bboxes1 and bboxes2, otherwise the ious between each aligned pair of
+    bboxes1 and bboxes2.
+
+    Args:
+        bboxes1 (Tensor): rotated bboxes 1. It has shape (N, 5), indicating
+            (x, y, w, h, theta) for each row. Note that theta is in radians.
+        bboxes2 (Tensor): rotated bboxes 2. It has shape (M, 5), indicating
+            (x, y, w, h, theta) for each row. Note that theta is in radians.
+        mode (str): "iou" (intersection over union) or "iof" (intersection
+            over foreground).
+        aligned (bool): If ``True``, compute the IoU between aligned pairs of
+            bboxes, which requires N == M. Default: False.
+
+    Returns:
+        ious (Tensor): shape (N, M) if ``aligned`` is False else shape (N, )
+    """
+    assert mode in ['iou', 'iof']
+    mode_dict = {'iou': 0, 'iof': 1}
+    mode_flag = mode_dict[mode]
+    rows = bboxes1.size(0)
+    cols = bboxes2.size(0)
+    if aligned:
+        ious = bboxes1.new_zeros(rows)
+    else:
+        ious = bboxes1.new_zeros((rows * cols))
+    bboxes1 = bboxes1.contiguous()
+    bboxes2 = bboxes2.contiguous()
+    ext_module.box_iou_rotated(
+        bboxes1, bboxes2, ious, mode_flag=mode_flag, aligned=aligned)
+    if not aligned:
+        ious = ious.view(rows, cols)
+    return ious
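A minimal usage sketch for ``box_iou_rotated``; boxes are (x_center, y_center, width, height, theta) with theta in radians. Whether this runs on CPU depends on how the ``_ext`` extension was built, so treat the device choice as an assumption.

import torch
from annotator.uniformer.mmcv.ops import box_iou_rotated

boxes1 = torch.tensor([[10.0, 10.0, 4.0, 6.0, 0.5],
                       [20.0, 20.0, 4.0, 6.0, 0.0]])
boxes2 = torch.tensor([[10.0, 10.0, 4.0, 6.0, 0.5]])
ious = box_iou_rotated(boxes1, boxes2)   # pairwise IoUs, shape (2, 1)
# The first box is identical to boxes2[0], so ious[0, 0] should be ~1.0.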
diff --git a/annotator/uniformer/mmcv/ops/carafe.py b/annotator/uniformer/mmcv/ops/carafe.py
new file mode 100644
index 0000000000000000000000000000000000000000..5154cb3abfccfbbe0a1b2daa67018dbf80aaf6d2
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/carafe.py
@@ -0,0 +1,287 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.nn.modules.module import Module
+
+from ..cnn import UPSAMPLE_LAYERS, normal_init, xavier_init
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+    'carafe_naive_forward', 'carafe_naive_backward', 'carafe_forward',
+    'carafe_backward'
+])
+
+
+class CARAFENaiveFunction(Function):
+
+    @staticmethod
+    def symbolic(g, features, masks, kernel_size, group_size, scale_factor):
+        return g.op(
+            'mmcv::MMCVCARAFENaive',
+            features,
+            masks,
+            kernel_size_i=kernel_size,
+            group_size_i=group_size,
+            scale_factor_f=scale_factor)
+
+    @staticmethod
+    def forward(ctx, features, masks, kernel_size, group_size, scale_factor):
+        assert scale_factor >= 1
+        assert masks.size(1) == kernel_size * kernel_size * group_size
+        assert masks.size(-1) == features.size(-1) * scale_factor
+        assert masks.size(-2) == features.size(-2) * scale_factor
+        assert features.size(1) % group_size == 0
+        assert (kernel_size - 1) % 2 == 0 and kernel_size >= 1
+        ctx.kernel_size = kernel_size
+        ctx.group_size = group_size
+        ctx.scale_factor = scale_factor
+        ctx.feature_size = features.size()
+        ctx.mask_size = masks.size()
+
+        n, c, h, w = features.size()
+        output = features.new_zeros((n, c, h * scale_factor, w * scale_factor))
+        ext_module.carafe_naive_forward(
+            features,
+            masks,
+            output,
+            kernel_size=kernel_size,
+            group_size=group_size,
+            scale_factor=scale_factor)
+
+        if features.requires_grad or masks.requires_grad:
+            ctx.save_for_backward(features, masks)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        assert grad_output.is_cuda
+
+        features, masks = ctx.saved_tensors
+        kernel_size = ctx.kernel_size
+        group_size = ctx.group_size
+        scale_factor = ctx.scale_factor
+
+        grad_input = torch.zeros_like(features)
+        grad_masks = torch.zeros_like(masks)
+        ext_module.carafe_naive_backward(
+            grad_output.contiguous(),
+            features,
+            masks,
+            grad_input,
+            grad_masks,
+            kernel_size=kernel_size,
+            group_size=group_size,
+            scale_factor=scale_factor)
+
+        return grad_input, grad_masks, None, None, None
+
+
+carafe_naive = CARAFENaiveFunction.apply
+
+
+class CARAFENaive(Module):
+
+    def __init__(self, kernel_size, group_size, scale_factor):
+        super(CARAFENaive, self).__init__()
+
+        assert isinstance(kernel_size, int) and isinstance(
+            group_size, int) and isinstance(scale_factor, int)
+        self.kernel_size = kernel_size
+        self.group_size = group_size
+        self.scale_factor = scale_factor
+
+    def forward(self, features, masks):
+        return carafe_naive(features, masks, self.kernel_size, self.group_size,
+                            self.scale_factor)
+
+
+class CARAFEFunction(Function):
+
+    @staticmethod
+    def symbolic(g, features, masks, kernel_size, group_size, scale_factor):
+        return g.op(
+            'mmcv::MMCVCARAFE',
+            features,
+            masks,
+            kernel_size_i=kernel_size,
+            group_size_i=group_size,
+            scale_factor_f=scale_factor)
+
+    @staticmethod
+    def forward(ctx, features, masks, kernel_size, group_size, scale_factor):
+        assert scale_factor >= 1
+        assert masks.size(1) == kernel_size * kernel_size * group_size
+        assert masks.size(-1) == features.size(-1) * scale_factor
+        assert masks.size(-2) == features.size(-2) * scale_factor
+        assert features.size(1) % group_size == 0
+        assert (kernel_size - 1) % 2 == 0 and kernel_size >= 1
+        ctx.kernel_size = kernel_size
+        ctx.group_size = group_size
+        ctx.scale_factor = scale_factor
+        ctx.feature_size = features.size()
+        ctx.mask_size = masks.size()
+
+        n, c, h, w = features.size()
+        output = features.new_zeros((n, c, h * scale_factor, w * scale_factor))
+        routput = features.new_zeros(output.size(), requires_grad=False)
+        rfeatures = features.new_zeros(features.size(), requires_grad=False)
+        rmasks = masks.new_zeros(masks.size(), requires_grad=False)
+        ext_module.carafe_forward(
+            features,
+            masks,
+            rfeatures,
+            routput,
+            rmasks,
+            output,
+            kernel_size=kernel_size,
+            group_size=group_size,
+            scale_factor=scale_factor)
+
+        if features.requires_grad or masks.requires_grad:
+            ctx.save_for_backward(features, masks, rfeatures)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        assert grad_output.is_cuda
+
+        features, masks, rfeatures = ctx.saved_tensors
+        kernel_size = ctx.kernel_size
+        group_size = ctx.group_size
+        scale_factor = ctx.scale_factor
+
+        rgrad_output = torch.zeros_like(grad_output, requires_grad=False)
+        rgrad_input_hs = torch.zeros_like(grad_output, requires_grad=False)
+        rgrad_input = torch.zeros_like(features, requires_grad=False)
+        rgrad_masks = torch.zeros_like(masks, requires_grad=False)
+        grad_input = torch.zeros_like(features, requires_grad=False)
+        grad_masks = torch.zeros_like(masks, requires_grad=False)
+        ext_module.carafe_backward(
+            grad_output.contiguous(),
+            rfeatures,
+            masks,
+            rgrad_output,
+            rgrad_input_hs,
+            rgrad_input,
+            rgrad_masks,
+            grad_input,
+            grad_masks,
+            kernel_size=kernel_size,
+            group_size=group_size,
+            scale_factor=scale_factor)
+        return grad_input, grad_masks, None, None, None
+
+
+carafe = CARAFEFunction.apply
+
+
+class CARAFE(Module):
+    """ CARAFE: Content-Aware ReAssembly of FEatures
+
+    Please refer to https://arxiv.org/abs/1905.02188 for more details.
+
+    Args:
+        kernel_size (int): reassemble kernel size
+        group_size (int): reassemble group size
+        scale_factor (int): upsample ratio
+
+    Returns:
+        upsampled feature map
+    """
+
+    def __init__(self, kernel_size, group_size, scale_factor):
+        super(CARAFE, self).__init__()
+
+        assert isinstance(kernel_size, int) and isinstance(
+            group_size, int) and isinstance(scale_factor, int)
+        self.kernel_size = kernel_size
+        self.group_size = group_size
+        self.scale_factor = scale_factor
+
+    def forward(self, features, masks):
+        return carafe(features, masks, self.kernel_size, self.group_size,
+                      self.scale_factor)
+
+
+@UPSAMPLE_LAYERS.register_module(name='carafe')
+class CARAFEPack(nn.Module):
+    """A unified package of CARAFE upsampler that contains: 1) channel
+    compressor 2) content encoder 3) CARAFE op.
+
+    Official implementation of ICCV 2019 paper
+    CARAFE: Content-Aware ReAssembly of FEatures
+    Please refer to https://arxiv.org/abs/1905.02188 for more details.
+
+    Args:
+        channels (int): input feature channels
+        scale_factor (int): upsample ratio
+        up_kernel (int): kernel size of CARAFE op
+        up_group (int): group size of CARAFE op
+        encoder_kernel (int): kernel size of content encoder
+        encoder_dilation (int): dilation of content encoder
+        compressed_channels (int): output channels of channels compressor
+
+    Returns:
+        upsampled feature map
+    """
+
+    def __init__(self,
+                 channels,
+                 scale_factor,
+                 up_kernel=5,
+                 up_group=1,
+                 encoder_kernel=3,
+                 encoder_dilation=1,
+                 compressed_channels=64):
+        super(CARAFEPack, self).__init__()
+        self.channels = channels
+        self.scale_factor = scale_factor
+        self.up_kernel = up_kernel
+        self.up_group = up_group
+        self.encoder_kernel = encoder_kernel
+        self.encoder_dilation = encoder_dilation
+        self.compressed_channels = compressed_channels
+        self.channel_compressor = nn.Conv2d(channels, self.compressed_channels,
+                                            1)
+        self.content_encoder = nn.Conv2d(
+            self.compressed_channels,
+            self.up_kernel * self.up_kernel * self.up_group *
+            self.scale_factor * self.scale_factor,
+            self.encoder_kernel,
+            padding=int((self.encoder_kernel - 1) * self.encoder_dilation / 2),
+            dilation=self.encoder_dilation,
+            groups=1)
+        self.init_weights()
+
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                xavier_init(m, distribution='uniform')
+        normal_init(self.content_encoder, std=0.001)
+
+    def kernel_normalizer(self, mask):
+        mask = F.pixel_shuffle(mask, self.scale_factor)
+        n, mask_c, h, w = mask.size()
+        # use float division explicitly,
+        # to avoid inconsistency while exporting to onnx
+        mask_channel = int(mask_c / float(self.up_kernel**2))
+        mask = mask.view(n, mask_channel, -1, h, w)
+
+        mask = F.softmax(mask, dim=2, dtype=mask.dtype)
+        mask = mask.view(n, mask_c, h, w).contiguous()
+
+        return mask
+
+    def feature_reassemble(self, x, mask):
+        x = carafe(x, mask, self.up_kernel, self.up_group, self.scale_factor)
+        return x
+
+    def forward(self, x):
+        compressed_x = self.channel_compressor(x)
+        mask = self.content_encoder(compressed_x)
+        mask = self.kernel_normalizer(mask)
+
+        x = self.feature_reassemble(x, mask)
+        return x
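A minimal usage sketch for ``CARAFEPack``, the self-contained upsampler built from the channel compressor, content encoder and CARAFE op above; the op itself requires the compiled ``_ext`` CUDA extension and a GPU.

import torch
from annotator.uniformer.mmcv.ops import CARAFEPack

if torch.cuda.is_available():
    upsampler = CARAFEPack(channels=64, scale_factor=2).cuda()
    x = torch.rand(2, 64, 16, 16, device='cuda')
    y = upsampler(x)
    assert y.shape == (2, 64, 32, 32)  # spatial dims upsampled by scale_factor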
diff --git a/annotator/uniformer/mmcv/ops/cc_attention.py b/annotator/uniformer/mmcv/ops/cc_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..9207aa95e6730bd9b3362dee612059a5f0ce1c5e
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/cc_attention.py
@@ -0,0 +1,83 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from annotator.uniformer.mmcv.cnn import PLUGIN_LAYERS, Scale
+
+
+def NEG_INF_DIAG(n, device):
+    """Returns a diagonal matrix of size [n, n].
+
+    The diagonal are all "-inf". This is for avoiding calculating the
+    overlapped element in the Criss-Cross twice.
+    """
+    return torch.diag(torch.tensor(float('-inf')).to(device).repeat(n), 0)
+
+
+@PLUGIN_LAYERS.register_module()
+class CrissCrossAttention(nn.Module):
+    """Criss-Cross Attention Module.
+
+    .. note::
+        Before v1.3.13, we use a CUDA op. Since v1.3.13, we switch
+        to a pure PyTorch and equivalent implementation. For more
+        details, please refer to https://github.com/open-mmlab/mmcv/pull/1201.
+
+        Speed comparison for one forward pass
+
+        - Input size: [2,512,97,97]
+        - Device: 1 NVIDIA GeForce RTX 2080 Ti
+
+        +-----------------------+---------------+------------+---------------+
+        |                       |PyTorch version|CUDA version|Relative speed |
+        +=======================+===============+============+===============+
+        |with torch.no_grad()   |0.00554402 s   |0.0299619 s |5.4x           |
+        +-----------------------+---------------+------------+---------------+
+        |no with torch.no_grad()|0.00562803 s   |0.0301349 s |5.4x           |
+        +-----------------------+---------------+------------+---------------+
+
+    Args:
+        in_channels (int): Channels of the input feature map.
+    """
+
+    def __init__(self, in_channels):
+        super().__init__()
+        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
+        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, 1)
+        self.value_conv = nn.Conv2d(in_channels, in_channels, 1)
+        self.gamma = Scale(0.)
+        self.in_channels = in_channels
+
+    def forward(self, x):
+        """forward function of Criss-Cross Attention.
+
+        Args:
+            x (Tensor): Input feature. \
+                shape (batch_size, in_channels, height, width)
+        Returns:
+            Tensor: Output of the layer, with shape of \
+            (batch_size, in_channels, height, width)
+        """
+        B, C, H, W = x.size()
+        query = self.query_conv(x)
+        key = self.key_conv(x)
+        value = self.value_conv(x)
+        energy_H = torch.einsum('bchw,bciw->bwhi', query, key) + NEG_INF_DIAG(
+            H, query.device)
+        energy_H = energy_H.transpose(1, 2)
+        energy_W = torch.einsum('bchw,bchj->bhwj', query, key)
+        attn = F.softmax(
+            torch.cat([energy_H, energy_W], dim=-1), dim=-1)  # [B,H,W,(H+W)]
+        out = torch.einsum('bciw,bhwi->bchw', value, attn[..., :H])
+        out += torch.einsum('bchj,bhwj->bchw', value, attn[..., H:])
+
+        out = self.gamma(out) + x
+        out = out.contiguous()
+
+        return out
+
+    def __repr__(self):
+        s = self.__class__.__name__
+        s += f'(in_channels={self.in_channels})'
+        return s
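A minimal usage sketch for ``CrissCrossAttention``. Because this is the pure PyTorch implementation described in the note above, it should also run on CPU, and the output keeps the input shape.

import torch
from annotator.uniformer.mmcv.ops import CrissCrossAttention

attn = CrissCrossAttention(in_channels=64)
x = torch.rand(2, 64, 16, 16)       # (batch, channels, height, width)
out = attn(x)
assert out.shape == x.shape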
diff --git a/annotator/uniformer/mmcv/ops/contour_expand.py b/annotator/uniformer/mmcv/ops/contour_expand.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea1111e1768b5f27e118bf7dbc0d9c70a7afd6d7
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/contour_expand.py
@@ -0,0 +1,49 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['contour_expand'])
+
+
+def contour_expand(kernel_mask, internal_kernel_label, min_kernel_area,
+                   kernel_num):
+    """Expand kernel contours so that foreground pixels are assigned into
+    instances.
+
+    Arguments:
+        kernel_mask (np.array or Tensor): The instance kernel mask with
+            size hxw.
+        internal_kernel_label (np.array or Tensor): The instance internal
+            kernel label with size hxw.
+        min_kernel_area (int): The minimum kernel area.
+        kernel_num (int): The instance kernel number.
+
+    Returns:
+        label (list): The instance index map with size hxw.
+    """
+    assert isinstance(kernel_mask, (torch.Tensor, np.ndarray))
+    assert isinstance(internal_kernel_label, (torch.Tensor, np.ndarray))
+    assert isinstance(min_kernel_area, int)
+    assert isinstance(kernel_num, int)
+
+    if isinstance(kernel_mask, np.ndarray):
+        kernel_mask = torch.from_numpy(kernel_mask)
+    if isinstance(internal_kernel_label, np.ndarray):
+        internal_kernel_label = torch.from_numpy(internal_kernel_label)
+
+    if torch.__version__ == 'parrots':
+        if kernel_mask.shape[0] == 0 or internal_kernel_label.shape[0] == 0:
+            label = []
+        else:
+            label = ext_module.contour_expand(
+                kernel_mask,
+                internal_kernel_label,
+                min_kernel_area=min_kernel_area,
+                kernel_num=kernel_num)
+            label = label.tolist()
+    else:
+        label = ext_module.contour_expand(kernel_mask, internal_kernel_label,
+                                          min_kernel_area, kernel_num)
+    return label
diff --git a/annotator/uniformer/mmcv/ops/corner_pool.py b/annotator/uniformer/mmcv/ops/corner_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..a33d798b43d405e4c86bee4cd6389be21ca9c637
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/corner_pool.py
@@ -0,0 +1,161 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+    'top_pool_forward', 'top_pool_backward', 'bottom_pool_forward',
+    'bottom_pool_backward', 'left_pool_forward', 'left_pool_backward',
+    'right_pool_forward', 'right_pool_backward'
+])
+
+_mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3}
+
+
+class TopPoolFunction(Function):
+
+    @staticmethod
+    def symbolic(g, input):
+        output = g.op(
+            'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['top']))
+        return output
+
+    @staticmethod
+    def forward(ctx, input):
+        output = ext_module.top_pool_forward(input)
+        ctx.save_for_backward(input)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, = ctx.saved_tensors
+        output = ext_module.top_pool_backward(input, grad_output)
+        return output
+
+
+class BottomPoolFunction(Function):
+
+    @staticmethod
+    def symbolic(g, input):
+        output = g.op(
+            'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['bottom']))
+        return output
+
+    @staticmethod
+    def forward(ctx, input):
+        output = ext_module.bottom_pool_forward(input)
+        ctx.save_for_backward(input)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, = ctx.saved_tensors
+        output = ext_module.bottom_pool_backward(input, grad_output)
+        return output
+
+
+class LeftPoolFunction(Function):
+
+    @staticmethod
+    def symbolic(g, input):
+        output = g.op(
+            'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['left']))
+        return output
+
+    @staticmethod
+    def forward(ctx, input):
+        output = ext_module.left_pool_forward(input)
+        ctx.save_for_backward(input)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, = ctx.saved_tensors
+        output = ext_module.left_pool_backward(input, grad_output)
+        return output
+
+
+class RightPoolFunction(Function):
+
+    @staticmethod
+    def symbolic(g, input):
+        output = g.op(
+            'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['right']))
+        return output
+
+    @staticmethod
+    def forward(ctx, input):
+        output = ext_module.right_pool_forward(input)
+        ctx.save_for_backward(input)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, = ctx.saved_tensors
+        output = ext_module.right_pool_backward(input, grad_output)
+        return output
+
+
+class CornerPool(nn.Module):
+    """Corner Pooling.
+
+    Corner Pooling is a new type of pooling layer that helps a
+    convolutional network better localize corners of bounding boxes.
+
+    Please refer to https://arxiv.org/abs/1808.01244 for more details.
+    Code is modified from https://github.com/princeton-vl/CornerNet-Lite.
+
+    Args:
+        mode(str): Pooling orientation for the pooling layer
+
+            - 'bottom': Bottom Pooling
+            - 'left': Left Pooling
+            - 'right': Right Pooling
+            - 'top': Top Pooling
+
+    Returns:
+        Feature map after pooling.
+    """
+
+    pool_functions = {
+        'bottom': BottomPoolFunction,
+        'left': LeftPoolFunction,
+        'right': RightPoolFunction,
+        'top': TopPoolFunction,
+    }
+
+    cummax_dim_flip = {
+        'bottom': (2, False),
+        'left': (3, True),
+        'right': (3, False),
+        'top': (2, True),
+    }
+
+    def __init__(self, mode):
+        super(CornerPool, self).__init__()
+        assert mode in self.pool_functions
+        self.mode = mode
+        self.corner_pool = self.pool_functions[mode]
+
+    def forward(self, x):
+        if torch.__version__ != 'parrots' and torch.__version__ >= '1.5.0':
+            if torch.onnx.is_in_onnx_export():
+                assert torch.__version__ >= '1.7.0', \
+                    'When `cummax` serves as an intermediate component whose '\
+                    'outputs is used as inputs for another modules, it\'s '\
+                    'expected that pytorch version must be >= 1.7.0, '\
+                    'otherwise Error appears like: `RuntimeError: tuple '\
+                    'appears in op that does not forward tuples, unsupported '\
+                    'kind: prim::PythonOp`.'
+
+            dim, flip = self.cummax_dim_flip[self.mode]
+            if flip:
+                x = x.flip(dim)
+            pool_tensor, _ = torch.cummax(x, dim=dim)
+            if flip:
+                pool_tensor = pool_tensor.flip(dim)
+            return pool_tensor
+        else:
+            return self.corner_pool.apply(x)
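A minimal usage sketch for ``CornerPool``. With PyTorch >= 1.5 (and not parrots) the forward pass falls back to ``torch.cummax``, so no compiled extension is needed; the output has the same shape as the input.

import torch
from annotator.uniformer.mmcv.ops import CornerPool

top_pool = CornerPool('top')
x = torch.rand(2, 8, 16, 16)
y = top_pool(x)                     # running maxima along the height axis
assert y.shape == x.shape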
diff --git a/annotator/uniformer/mmcv/ops/correlation.py b/annotator/uniformer/mmcv/ops/correlation.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d0b79c301b29915dfaf4d2b1846c59be73127d3
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/correlation.py
@@ -0,0 +1,196 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import Tensor, nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+    '_ext', ['correlation_forward', 'correlation_backward'])
+
+
+class CorrelationFunction(Function):
+
+    @staticmethod
+    def forward(ctx,
+                input1,
+                input2,
+                kernel_size=1,
+                max_displacement=1,
+                stride=1,
+                padding=1,
+                dilation=1,
+                dilation_patch=1):
+
+        ctx.save_for_backward(input1, input2)
+
+        kH, kW = ctx.kernel_size = _pair(kernel_size)
+        patch_size = max_displacement * 2 + 1
+        ctx.patch_size = patch_size
+        dH, dW = ctx.stride = _pair(stride)
+        padH, padW = ctx.padding = _pair(padding)
+        dilationH, dilationW = ctx.dilation = _pair(dilation)
+        dilation_patchH, dilation_patchW = ctx.dilation_patch = _pair(
+            dilation_patch)
+
+        output_size = CorrelationFunction._output_size(ctx, input1)
+
+        output = input1.new_zeros(output_size)
+
+        ext_module.correlation_forward(
+            input1,
+            input2,
+            output,
+            kH=kH,
+            kW=kW,
+            patchH=patch_size,
+            patchW=patch_size,
+            padH=padH,
+            padW=padW,
+            dilationH=dilationH,
+            dilationW=dilationW,
+            dilation_patchH=dilation_patchH,
+            dilation_patchW=dilation_patchW,
+            dH=dH,
+            dW=dW)
+
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        input1, input2 = ctx.saved_tensors
+
+        kH, kW = ctx.kernel_size
+        patch_size = ctx.patch_size
+        padH, padW = ctx.padding
+        dilationH, dilationW = ctx.dilation
+        dilation_patchH, dilation_patchW = ctx.dilation_patch
+        dH, dW = ctx.stride
+        grad_input1 = torch.zeros_like(input1)
+        grad_input2 = torch.zeros_like(input2)
+
+        ext_module.correlation_backward(
+            grad_output,
+            input1,
+            input2,
+            grad_input1,
+            grad_input2,
+            kH=kH,
+            kW=kW,
+            patchH=patch_size,
+            patchW=patch_size,
+            padH=padH,
+            padW=padW,
+            dilationH=dilationH,
+            dilationW=dilationW,
+            dilation_patchH=dilation_patchH,
+            dilation_patchW=dilation_patchW,
+            dH=dH,
+            dW=dW)
+        return grad_input1, grad_input2, None, None, None, None, None, None
+
+    @staticmethod
+    def _output_size(ctx, input1):
+        iH, iW = input1.size(2), input1.size(3)
+        batch_size = input1.size(0)
+        kH, kW = ctx.kernel_size
+        patch_size = ctx.patch_size
+        dH, dW = ctx.stride
+        padH, padW = ctx.padding
+        dilationH, dilationW = ctx.dilation
+        dilatedKH = (kH - 1) * dilationH + 1
+        dilatedKW = (kW - 1) * dilationW + 1
+
+        oH = int((iH + 2 * padH - dilatedKH) / dH + 1)
+        oW = int((iW + 2 * padW - dilatedKW) / dW + 1)
+
+        output_size = (batch_size, patch_size, patch_size, oH, oW)
+        return output_size
+
+
+class Correlation(nn.Module):
+    r"""Correlation operator
+
+    This correlation operator works for optical flow correlation computation.
+
+    There are two batched tensors with shape :math:`(N, C, H, W)`,
+    and the correlation output's shape is :math:`(N, max\_displacement \times
+    2 + 1, max\_displacement \times 2 + 1, H_{out}, W_{out})`
+
+    where
+
+    .. math::
+        H_{out} = \left\lfloor\frac{H_{in}  + 2 \times padding -
+            dilation \times (kernel\_size - 1) - 1}
+            {stride} + 1\right\rfloor
+
+    .. math::
+        W_{out} = \left\lfloor\frac{W_{in}  + 2 \times padding - dilation
+            \times (kernel\_size - 1) - 1}
+            {stride} + 1\right\rfloor
+
+    The correlation item :math:`Corr(N_i, dx, dy)` is computed as the valid
+    sliding-window convolution between input1 and a shifted input2,
+
+    .. math::
+        Corr(N_i, dx, dy) =
+        \sum_{c=0}^{C-1}
+        input1(N_i, c) \star
+        \mathcal{S}(input2(N_i, c), dy, dx)
+
+    where :math:`\star` is the valid 2d sliding window convolution operator,
+    :math:`\mathcal{S}` shifts the input features (with zero padding at the
+    margins), and :math:`dx, dy` are the shifting distances, :math:`dx, dy \in
+    [-max\_displacement \times dilation\_patch, max\_displacement \times
+    dilation\_patch]`.
+
+    Args:
+        kernel_size (int): The size of the sliding window, i.e. the local
+            neighborhood around each center point involved in the correlation
+            computation. Defaults to 1.
+        max_displacement (int): The radius used to build the correlation
+            volume; the actual search range is additionally dilated by
+            ``dilation_patch``. Defaults to 1.
+        stride (int): The stride of the sliding blocks in the input spatial
+            dimensions. Defaults to 1.
+        padding (int): Zero padding added to all four sides of ``input1``.
+            Defaults to 0.
+        dilation (int): The spacing within the local neighborhood involved in
+            the correlation. Defaults to 1.
+        dilation_patch (int): The spacing between positions at which the
+            correlation is computed. Defaults to 1.
+    """
+
+    def __init__(self,
+                 kernel_size: int = 1,
+                 max_displacement: int = 1,
+                 stride: int = 1,
+                 padding: int = 0,
+                 dilation: int = 1,
+                 dilation_patch: int = 1) -> None:
+        super().__init__()
+        self.kernel_size = kernel_size
+        self.max_displacement = max_displacement
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.dilation_patch = dilation_patch
+
+    def forward(self, input1: Tensor, input2: Tensor) -> Tensor:
+        return CorrelationFunction.apply(input1, input2, self.kernel_size,
+                                         self.max_displacement, self.stride,
+                                         self.padding, self.dilation,
+                                         self.dilation_patch)
+
+    def __repr__(self) -> str:
+        s = self.__class__.__name__
+        s += f'(kernel_size={self.kernel_size}, '
+        s += f'max_displacement={self.max_displacement}, '
+        s += f'stride={self.stride}, '
+        s += f'padding={self.padding}, '
+        s += f'dilation={self.dilation}, '
+        s += f'dilation_patch={self.dilation_patch})'
+        return s
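As a quick illustration of how the Correlation module above is typically used (not part of the patch), here is a minimal sketch. It assumes a CUDA device, a build of the compiled `_ext` extension, and that `Correlation` is re-exported from `annotator.uniformer.mmcv.ops`; the shapes follow the docstring.

import torch

from annotator.uniformer.mmcv.ops import Correlation  # assumed export path


def correlation_demo():
    # Cost volume between two feature maps, as used in optical flow networks.
    corr = Correlation(max_displacement=4).cuda()
    feat1 = torch.randn(2, 64, 48, 64, device='cuda')
    feat2 = torch.randn(2, 64, 48, 64, device='cuda')
    cost_volume = corr(feat1, feat2)
    # patch_size = 2 * max_displacement + 1 = 9; the spatial size is preserved
    # because kernel_size=1, stride=1 and padding=0 by default.
    assert cost_volume.shape == (2, 9, 9, 48, 64)
    return cost_volume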
diff --git a/annotator/uniformer/mmcv/ops/deform_conv.py b/annotator/uniformer/mmcv/ops/deform_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3f8c75ee774823eea334e3b3732af6a18f55038
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/deform_conv.py
@@ -0,0 +1,405 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair, _single
+
+from annotator.uniformer.mmcv.utils import deprecated_api_warning
+from ..cnn import CONV_LAYERS
+from ..utils import ext_loader, print_log
+
+ext_module = ext_loader.load_ext('_ext', [
+    'deform_conv_forward', 'deform_conv_backward_input',
+    'deform_conv_backward_parameters'
+])
+
+
+class DeformConv2dFunction(Function):
+
+    @staticmethod
+    def symbolic(g,
+                 input,
+                 offset,
+                 weight,
+                 stride,
+                 padding,
+                 dilation,
+                 groups,
+                 deform_groups,
+                 bias=False,
+                 im2col_step=32):
+        return g.op(
+            'mmcv::MMCVDeformConv2d',
+            input,
+            offset,
+            weight,
+            stride_i=stride,
+            padding_i=padding,
+            dilation_i=dilation,
+            groups_i=groups,
+            deform_groups_i=deform_groups,
+            bias_i=bias,
+            im2col_step_i=im2col_step)
+
+    @staticmethod
+    def forward(ctx,
+                input,
+                offset,
+                weight,
+                stride=1,
+                padding=0,
+                dilation=1,
+                groups=1,
+                deform_groups=1,
+                bias=False,
+                im2col_step=32):
+        if input is not None and input.dim() != 4:
+            raise ValueError(
+                'Expected 4D tensor as input, got '
+                f'{input.dim()}D tensor instead.')
+        assert bias is False, 'Only support bias is False.'
+        ctx.stride = _pair(stride)
+        ctx.padding = _pair(padding)
+        ctx.dilation = _pair(dilation)
+        ctx.groups = groups
+        ctx.deform_groups = deform_groups
+        ctx.im2col_step = im2col_step
+
+        # When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
+        # amp won't cast the type of model (float32), but "offset" is cast
+        # to float16 by nn.Conv2d automatically, leading to the type
+        # mismatch with input (when it is float32) or weight.
+        # The flag for whether to use fp16 or amp is the type of "offset",
+        # we cast weight and input to temporarily support fp16 and amp
+        # whatever the pytorch version is.
+        input = input.type_as(offset)
+        weight = weight.type_as(input)
+        ctx.save_for_backward(input, offset, weight)
+
+        output = input.new_empty(
+            DeformConv2dFunction._output_size(ctx, input, weight))
+
+        ctx.bufs_ = [input.new_empty(0), input.new_empty(0)]  # columns, ones
+
+        cur_im2col_step = min(ctx.im2col_step, input.size(0))
+        assert (input.size(0) %
+                cur_im2col_step) == 0, 'im2col step must divide batchsize'
+        ext_module.deform_conv_forward(
+            input,
+            weight,
+            offset,
+            output,
+            ctx.bufs_[0],
+            ctx.bufs_[1],
+            kW=weight.size(3),
+            kH=weight.size(2),
+            dW=ctx.stride[1],
+            dH=ctx.stride[0],
+            padW=ctx.padding[1],
+            padH=ctx.padding[0],
+            dilationW=ctx.dilation[1],
+            dilationH=ctx.dilation[0],
+            group=ctx.groups,
+            deformable_group=ctx.deform_groups,
+            im2col_step=cur_im2col_step)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        input, offset, weight = ctx.saved_tensors
+
+        grad_input = grad_offset = grad_weight = None
+
+        cur_im2col_step = min(ctx.im2col_step, input.size(0))
+        assert (input.size(0) % cur_im2col_step
+                ) == 0, 'batch size must be divisible by im2col_step'
+
+        grad_output = grad_output.contiguous()
+        if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+            grad_input = torch.zeros_like(input)
+            grad_offset = torch.zeros_like(offset)
+            ext_module.deform_conv_backward_input(
+                input,
+                offset,
+                grad_output,
+                grad_input,
+                grad_offset,
+                weight,
+                ctx.bufs_[0],
+                kW=weight.size(3),
+                kH=weight.size(2),
+                dW=ctx.stride[1],
+                dH=ctx.stride[0],
+                padW=ctx.padding[1],
+                padH=ctx.padding[0],
+                dilationW=ctx.dilation[1],
+                dilationH=ctx.dilation[0],
+                group=ctx.groups,
+                deformable_group=ctx.deform_groups,
+                im2col_step=cur_im2col_step)
+
+        if ctx.needs_input_grad[2]:
+            grad_weight = torch.zeros_like(weight)
+            ext_module.deform_conv_backward_parameters(
+                input,
+                offset,
+                grad_output,
+                grad_weight,
+                ctx.bufs_[0],
+                ctx.bufs_[1],
+                kW=weight.size(3),
+                kH=weight.size(2),
+                dW=ctx.stride[1],
+                dH=ctx.stride[0],
+                padW=ctx.padding[1],
+                padH=ctx.padding[0],
+                dilationW=ctx.dilation[1],
+                dilationH=ctx.dilation[0],
+                group=ctx.groups,
+                deformable_group=ctx.deform_groups,
+                scale=1,
+                im2col_step=cur_im2col_step)
+
+        return grad_input, grad_offset, grad_weight, \
+            None, None, None, None, None, None, None
+
+    @staticmethod
+    def _output_size(ctx, input, weight):
+        channels = weight.size(0)
+        output_size = (input.size(0), channels)
+        for d in range(input.dim() - 2):
+            in_size = input.size(d + 2)
+            pad = ctx.padding[d]
+            kernel = ctx.dilation[d] * (weight.size(d + 2) - 1) + 1
+            stride_ = ctx.stride[d]
+            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
+        if not all(map(lambda s: s > 0, output_size)):
+            raise ValueError(
+                'convolution input is too small (output would be ' +
+                'x'.join(map(str, output_size)) + ')')
+        return output_size
+
+
+deform_conv2d = DeformConv2dFunction.apply
+
+
+class DeformConv2d(nn.Module):
+    r"""Deformable 2D convolution.
+
+    Applies a deformable 2D convolution over an input signal composed of
+    several input planes. DeformConv2d was described in the paper
+    `Deformable Convolutional Networks
+    <https://arxiv.org/pdf/1703.06211.pdf>`_
+
+    Note:
+        The argument ``im2col_step`` was added in version 1.3.17. It specifies
+        the number of samples processed by ``im2col_cuda_kernel`` per call,
+        which lets users choose ``batch_size`` and ``im2col_step`` more
+        flexibly and resolves `issue mmcv#1440
+        <https://github.com/open-mmlab/mmcv/issues/1440>`_.
+
+    Args:
+        in_channels (int): Number of channels in the input image.
+        out_channels (int): Number of channels produced by the convolution.
+        kernel_size(int, tuple): Size of the convolving kernel.
+        stride(int, tuple): Stride of the convolution. Default: 1.
+        padding (int or tuple): Zero-padding added to both sides of the input.
+            Default: 0.
+        dilation (int or tuple): Spacing between kernel elements. Default: 1.
+        groups (int): Number of blocked connections from input channels to
+            output channels. Default: 1.
+        deform_groups (int): Number of deformable group partitions.
+            Default: 1.
+        bias (bool): If True, adds a learnable bias to the output.
+            Default: False.
+        im2col_step (int): Number of samples processed by im2col_cuda_kernel
+            per call. It will work when ``batch_size`` > ``im2col_step``, but
+            ``batch_size`` must be divisible by ``im2col_step``. Default: 32.
+            `New in version 1.3.17.`
+    """
+
+    @deprecated_api_warning({'deformable_groups': 'deform_groups'},
+                            cls_name='DeformConv2d')
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Tuple[int, ...]],
+                 stride: Union[int, Tuple[int, ...]] = 1,
+                 padding: Union[int, Tuple[int, ...]] = 0,
+                 dilation: Union[int, Tuple[int, ...]] = 1,
+                 groups: int = 1,
+                 deform_groups: int = 1,
+                 bias: bool = False,
+                 im2col_step: int = 32) -> None:
+        super(DeformConv2d, self).__init__()
+
+        assert not bias, \
+            f'bias={bias} is not supported in DeformConv2d.'
+        assert in_channels % groups == 0, \
+            f'in_channels {in_channels} is not divisible by groups {groups}'
+        assert out_channels % groups == 0, \
+            f'out_channels {out_channels} is not divisible by groups {groups}'
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = _pair(kernel_size)
+        self.stride = _pair(stride)
+        self.padding = _pair(padding)
+        self.dilation = _pair(dilation)
+        self.groups = groups
+        self.deform_groups = deform_groups
+        self.im2col_step = im2col_step
+        # enable compatibility with nn.Conv2d
+        self.transposed = False
+        self.output_padding = _single(0)
+
+        # only weight, no bias
+        self.weight = nn.Parameter(
+            torch.Tensor(out_channels, in_channels // self.groups,
+                         *self.kernel_size))
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        # switch the initialization of `self.weight` to the standard kaiming
+        # method described in `Delving deep into rectifiers: Surpassing
+        # human-level performance on ImageNet classification` - He, K. et al.
+        # (2015), using a uniform distribution
+        nn.init.kaiming_uniform_(self.weight, nonlinearity='relu')
+
+    def forward(self, x: Tensor, offset: Tensor) -> Tensor:
+        """Deformable Convolutional forward function.
+
+        Args:
+            x (Tensor): Input feature, shape (B, C_in, H_in, W_in)
+            offset (Tensor): Offset for deformable convolution, shape
+                (B, deform_groups*kernel_size[0]*kernel_size[1]*2,
+                H_out, W_out), H_out, W_out are equal to the output's.
+
+                An offset is like `[y0, x0, y1, x1, y2, x2, ..., y8, x8]`.
+                The spatial arrangement is like:
+
+                .. code:: text
+
+                    (x0, y0) (x1, y1) (x2, y2)
+                    (x3, y3) (x4, y4) (x5, y5)
+                    (x6, y6) (x7, y7) (x8, y8)
+
+        Returns:
+            Tensor: Output of the layer.
+        """
+        # Pad the input when it is smaller than the kernel, to avoid the
+        # assert error in deform_conv_cuda.cpp:128.
+        input_pad = (x.size(2) < self.kernel_size[0]) or (x.size(3) <
+                                                          self.kernel_size[1])
+        if input_pad:
+            pad_h = max(self.kernel_size[0] - x.size(2), 0)
+            pad_w = max(self.kernel_size[1] - x.size(3), 0)
+            x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
+            offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0)
+            offset = offset.contiguous()
+        out = deform_conv2d(x, offset, self.weight, self.stride, self.padding,
+                            self.dilation, self.groups, self.deform_groups,
+                            False, self.im2col_step)
+        if input_pad:
+            out = out[:, :, :out.size(2) - pad_h, :out.size(3) -
+                      pad_w].contiguous()
+        return out
+
+    def __repr__(self):
+        s = self.__class__.__name__
+        s += f'(in_channels={self.in_channels},\n'
+        s += f'out_channels={self.out_channels},\n'
+        s += f'kernel_size={self.kernel_size},\n'
+        s += f'stride={self.stride},\n'
+        s += f'padding={self.padding},\n'
+        s += f'dilation={self.dilation},\n'
+        s += f'groups={self.groups},\n'
+        s += f'deform_groups={self.deform_groups},\n'
+        # bias is not supported in DeformConv2d.
+        s += 'bias=False)'
+        return s
+
+
+@CONV_LAYERS.register_module('DCN')
+class DeformConv2dPack(DeformConv2d):
+    """A Deformable Conv Encapsulation that acts as normal Conv layers.
+
+    The offset tensor is like `[y0, x0, y1, x1, y2, x2, ..., y8, x8]`.
+    The spatial arrangement is like:
+
+    .. code:: text
+
+        (x0, y0) (x1, y1) (x2, y2)
+        (x3, y3) (x4, y4) (x5, y5)
+        (x6, y6) (x7, y7) (x8, y8)
+
+    Args:
+        in_channels (int): Same as nn.Conv2d.
+        out_channels (int): Same as nn.Conv2d.
+        kernel_size (int or tuple[int]): Same as nn.Conv2d.
+        stride (int or tuple[int]): Same as nn.Conv2d.
+        padding (int or tuple[int]): Same as nn.Conv2d.
+        dilation (int or tuple[int]): Same as nn.Conv2d.
+        groups (int): Same as nn.Conv2d.
+        bias (bool or str): If specified as `auto`, it will be decided by the
+            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+            False.
+    """
+
+    _version = 2
+
+    def __init__(self, *args, **kwargs):
+        super(DeformConv2dPack, self).__init__(*args, **kwargs)
+        self.conv_offset = nn.Conv2d(
+            self.in_channels,
+            self.deform_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
+            kernel_size=self.kernel_size,
+            stride=_pair(self.stride),
+            padding=_pair(self.padding),
+            dilation=_pair(self.dilation),
+            bias=True)
+        self.init_offset()
+
+    def init_offset(self):
+        self.conv_offset.weight.data.zero_()
+        self.conv_offset.bias.data.zero_()
+
+    def forward(self, x):
+        offset = self.conv_offset(x)
+        return deform_conv2d(x, offset, self.weight, self.stride, self.padding,
+                             self.dilation, self.groups, self.deform_groups,
+                             False, self.im2col_step)
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        version = local_metadata.get('version', None)
+
+        if version is None or version < 2:
+            # the key is different in early versions
+            # In version < 2, DeformConvPack loads previous benchmark models.
+            if (prefix + 'conv_offset.weight' not in state_dict
+                    and prefix[:-1] + '_offset.weight' in state_dict):
+                state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
+                    prefix[:-1] + '_offset.weight')
+            if (prefix + 'conv_offset.bias' not in state_dict
+                    and prefix[:-1] + '_offset.bias' in state_dict):
+                state_dict[prefix +
+                           'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
+                                                                '_offset.bias')
+
+        if version is not None and version > 1:
+            print_log(
+                f'DeformConv2dPack {prefix.rstrip(".")} is upgraded to '
+                'version 2.',
+                logger='root')
+
+        super()._load_from_state_dict(state_dict, prefix, local_metadata,
+                                      strict, missing_keys, unexpected_keys,
+                                      error_msgs)
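A minimal usage sketch for DeformConv2dPack (illustration only, not part of the patch). It assumes a CUDA device, the compiled `_ext` extension, and that `DeformConv2dPack` is re-exported from `annotator.uniformer.mmcv.ops`. Because `init_offset` zeroes the offset branch, a freshly built layer initially behaves like a plain convolution.

import torch

from annotator.uniformer.mmcv.ops import DeformConv2dPack  # assumed export path


def deform_conv_demo():
    # Drop-in replacement for nn.Conv2d; offsets are predicted internally
    # by layer.conv_offset from the input feature map.
    layer = DeformConv2dPack(16, 32, kernel_size=3, padding=1).cuda()
    x = torch.randn(1, 16, 56, 56, device='cuda')
    y = layer(x)
    assert y.shape == (1, 32, 56, 56)
    return y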
diff --git a/annotator/uniformer/mmcv/ops/deform_roi_pool.py b/annotator/uniformer/mmcv/ops/deform_roi_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc245ba91fee252226ba22e76bb94a35db9a629b
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/deform_roi_pool.py
@@ -0,0 +1,204 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch import nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+    '_ext', ['deform_roi_pool_forward', 'deform_roi_pool_backward'])
+
+
+class DeformRoIPoolFunction(Function):
+
+    @staticmethod
+    def symbolic(g, input, rois, offset, output_size, spatial_scale,
+                 sampling_ratio, gamma):
+        return g.op(
+            'mmcv::MMCVDeformRoIPool',
+            input,
+            rois,
+            offset,
+            pooled_height_i=output_size[0],
+            pooled_width_i=output_size[1],
+            spatial_scale_f=spatial_scale,
+            sampling_ratio_f=sampling_ratio,
+            gamma_f=gamma)
+
+    @staticmethod
+    def forward(ctx,
+                input,
+                rois,
+                offset,
+                output_size,
+                spatial_scale=1.0,
+                sampling_ratio=0,
+                gamma=0.1):
+        if offset is None:
+            offset = input.new_zeros(0)
+        ctx.output_size = _pair(output_size)
+        ctx.spatial_scale = float(spatial_scale)
+        ctx.sampling_ratio = int(sampling_ratio)
+        ctx.gamma = float(gamma)
+
+        assert rois.size(1) == 5, 'RoI must be (idx, x1, y1, x2, y2)!'
+
+        output_shape = (rois.size(0), input.size(1), ctx.output_size[0],
+                        ctx.output_size[1])
+        output = input.new_zeros(output_shape)
+
+        ext_module.deform_roi_pool_forward(
+            input,
+            rois,
+            offset,
+            output,
+            pooled_height=ctx.output_size[0],
+            pooled_width=ctx.output_size[1],
+            spatial_scale=ctx.spatial_scale,
+            sampling_ratio=ctx.sampling_ratio,
+            gamma=ctx.gamma)
+
+        ctx.save_for_backward(input, rois, offset)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        input, rois, offset = ctx.saved_tensors
+        grad_input = grad_output.new_zeros(input.shape)
+        grad_offset = grad_output.new_zeros(offset.shape)
+
+        ext_module.deform_roi_pool_backward(
+            grad_output,
+            input,
+            rois,
+            offset,
+            grad_input,
+            grad_offset,
+            pooled_height=ctx.output_size[0],
+            pooled_width=ctx.output_size[1],
+            spatial_scale=ctx.spatial_scale,
+            sampling_ratio=ctx.sampling_ratio,
+            gamma=ctx.gamma)
+        if grad_offset.numel() == 0:
+            grad_offset = None
+        return grad_input, None, grad_offset, None, None, None, None
+
+
+deform_roi_pool = DeformRoIPoolFunction.apply
+
+
+class DeformRoIPool(nn.Module):
+
+    def __init__(self,
+                 output_size,
+                 spatial_scale=1.0,
+                 sampling_ratio=0,
+                 gamma=0.1):
+        super(DeformRoIPool, self).__init__()
+        self.output_size = _pair(output_size)
+        self.spatial_scale = float(spatial_scale)
+        self.sampling_ratio = int(sampling_ratio)
+        self.gamma = float(gamma)
+
+    def forward(self, input, rois, offset=None):
+        return deform_roi_pool(input, rois, offset, self.output_size,
+                               self.spatial_scale, self.sampling_ratio,
+                               self.gamma)
+
+
+class DeformRoIPoolPack(DeformRoIPool):
+
+    def __init__(self,
+                 output_size,
+                 output_channels,
+                 deform_fc_channels=1024,
+                 spatial_scale=1.0,
+                 sampling_ratio=0,
+                 gamma=0.1):
+        super(DeformRoIPoolPack, self).__init__(output_size, spatial_scale,
+                                                sampling_ratio, gamma)
+
+        self.output_channels = output_channels
+        self.deform_fc_channels = deform_fc_channels
+
+        self.offset_fc = nn.Sequential(
+            nn.Linear(
+                self.output_size[0] * self.output_size[1] *
+                self.output_channels, self.deform_fc_channels),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.deform_fc_channels, self.deform_fc_channels),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.deform_fc_channels,
+                      self.output_size[0] * self.output_size[1] * 2))
+        self.offset_fc[-1].weight.data.zero_()
+        self.offset_fc[-1].bias.data.zero_()
+
+    def forward(self, input, rois):
+        assert input.size(1) == self.output_channels
+        x = deform_roi_pool(input, rois, None, self.output_size,
+                            self.spatial_scale, self.sampling_ratio,
+                            self.gamma)
+        rois_num = rois.size(0)
+        offset = self.offset_fc(x.view(rois_num, -1))
+        offset = offset.view(rois_num, 2, self.output_size[0],
+                             self.output_size[1])
+        return deform_roi_pool(input, rois, offset, self.output_size,
+                               self.spatial_scale, self.sampling_ratio,
+                               self.gamma)
+
+
+class ModulatedDeformRoIPoolPack(DeformRoIPool):
+
+    def __init__(self,
+                 output_size,
+                 output_channels,
+                 deform_fc_channels=1024,
+                 spatial_scale=1.0,
+                 sampling_ratio=0,
+                 gamma=0.1):
+        super(ModulatedDeformRoIPoolPack,
+              self).__init__(output_size, spatial_scale, sampling_ratio, gamma)
+
+        self.output_channels = output_channels
+        self.deform_fc_channels = deform_fc_channels
+
+        self.offset_fc = nn.Sequential(
+            nn.Linear(
+                self.output_size[0] * self.output_size[1] *
+                self.output_channels, self.deform_fc_channels),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.deform_fc_channels, self.deform_fc_channels),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.deform_fc_channels,
+                      self.output_size[0] * self.output_size[1] * 2))
+        self.offset_fc[-1].weight.data.zero_()
+        self.offset_fc[-1].bias.data.zero_()
+
+        self.mask_fc = nn.Sequential(
+            nn.Linear(
+                self.output_size[0] * self.output_size[1] *
+                self.output_channels, self.deform_fc_channels),
+            nn.ReLU(inplace=True),
+            nn.Linear(self.deform_fc_channels,
+                      self.output_size[0] * self.output_size[1] * 1),
+            nn.Sigmoid())
+        self.mask_fc[2].weight.data.zero_()
+        self.mask_fc[2].bias.data.zero_()
+
+    def forward(self, input, rois):
+        assert input.size(1) == self.output_channels
+        x = deform_roi_pool(input, rois, None, self.output_size,
+                            self.spatial_scale, self.sampling_ratio,
+                            self.gamma)
+        rois_num = rois.size(0)
+        offset = self.offset_fc(x.view(rois_num, -1))
+        offset = offset.view(rois_num, 2, self.output_size[0],
+                             self.output_size[1])
+        mask = self.mask_fc(x.view(rois_num, -1))
+        mask = mask.view(rois_num, 1, self.output_size[0], self.output_size[1])
+        d = deform_roi_pool(input, rois, offset, self.output_size,
+                            self.spatial_scale, self.sampling_ratio,
+                            self.gamma)
+        return d * mask
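A minimal usage sketch for DeformRoIPoolPack (illustration only). RoIs follow the `(batch_idx, x1, y1, x2, y2)` layout asserted in the forward pass; a CUDA device, the compiled `_ext` extension, and the export path are assumptions.

import torch

from annotator.uniformer.mmcv.ops import DeformRoIPoolPack  # assumed export path


def deform_roi_pool_demo():
    # Pool a 7x7 feature for one RoI from a stride-16 feature map.
    pool = DeformRoIPoolPack(
        output_size=7, output_channels=256, spatial_scale=1.0 / 16).cuda()
    feats = torch.randn(1, 256, 50, 68, device='cuda')
    rois = torch.tensor([[0., 32., 32., 160., 160.]], device='cuda')
    out = pool(feats, rois)
    assert out.shape == (1, 256, 7, 7)
    return out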
diff --git a/annotator/uniformer/mmcv/ops/deprecated_wrappers.py b/annotator/uniformer/mmcv/ops/deprecated_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2e593df9ee57637038683d7a1efaa347b2b69e7
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/deprecated_wrappers.py
@@ -0,0 +1,43 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# This file is for backward compatibility.
+# Module wrappers for empty tensor have been moved to mmcv.cnn.bricks.
+import warnings
+
+from ..cnn.bricks.wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d
+
+
+class Conv2d_deprecated(Conv2d):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        warnings.warn(
+            'Importing Conv2d wrapper from "mmcv.ops" will be deprecated in'
+            ' the future. Please import them from "mmcv.cnn" instead')
+
+
+class ConvTranspose2d_deprecated(ConvTranspose2d):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        warnings.warn(
+            'Importing ConvTranspose2d wrapper from "mmcv.ops" will be '
+            'deprecated in the future. Please import them from "mmcv.cnn" '
+            'instead')
+
+
+class MaxPool2d_deprecated(MaxPool2d):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        warnings.warn(
+            'Importing MaxPool2d wrapper from "mmcv.ops" will be deprecated in'
+            ' the future. Please import them from "mmcv.cnn" instead')
+
+
+class Linear_deprecated(Linear):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        warnings.warn(
+            'Importing Linear wrapper from "mmcv.ops" will be deprecated in'
+            ' the future. Please import them from "mmcv.cnn" instead')
diff --git a/annotator/uniformer/mmcv/ops/focal_loss.py b/annotator/uniformer/mmcv/ops/focal_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..763bc93bd2575c49ca8ccf20996bbd92d1e0d1a4
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/focal_loss.py
@@ -0,0 +1,212 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+    'sigmoid_focal_loss_forward', 'sigmoid_focal_loss_backward',
+    'softmax_focal_loss_forward', 'softmax_focal_loss_backward'
+])
+
+
+class SigmoidFocalLossFunction(Function):
+
+    @staticmethod
+    def symbolic(g, input, target, gamma, alpha, weight, reduction):
+        return g.op(
+            'mmcv::MMCVSigmoidFocalLoss',
+            input,
+            target,
+            gamma_f=gamma,
+            alpha_f=alpha,
+            weight_f=weight,
+            reduction_s=reduction)
+
+    @staticmethod
+    def forward(ctx,
+                input,
+                target,
+                gamma=2.0,
+                alpha=0.25,
+                weight=None,
+                reduction='mean'):
+
+        assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
+        assert input.dim() == 2
+        assert target.dim() == 1
+        assert input.size(0) == target.size(0)
+        if weight is None:
+            weight = input.new_empty(0)
+        else:
+            assert weight.dim() == 1
+            assert input.size(1) == weight.size(0)
+        ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2}
+        assert reduction in ctx.reduction_dict.keys()
+
+        ctx.gamma = float(gamma)
+        ctx.alpha = float(alpha)
+        ctx.reduction = ctx.reduction_dict[reduction]
+
+        output = input.new_zeros(input.size())
+
+        ext_module.sigmoid_focal_loss_forward(
+            input, target, weight, output, gamma=ctx.gamma, alpha=ctx.alpha)
+        if ctx.reduction == ctx.reduction_dict['mean']:
+            output = output.sum() / input.size(0)
+        elif ctx.reduction == ctx.reduction_dict['sum']:
+            output = output.sum()
+        ctx.save_for_backward(input, target, weight)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        input, target, weight = ctx.saved_tensors
+
+        grad_input = input.new_zeros(input.size())
+
+        ext_module.sigmoid_focal_loss_backward(
+            input,
+            target,
+            weight,
+            grad_input,
+            gamma=ctx.gamma,
+            alpha=ctx.alpha)
+
+        grad_input *= grad_output
+        if ctx.reduction == ctx.reduction_dict['mean']:
+            grad_input /= input.size(0)
+        return grad_input, None, None, None, None, None
+
+
+sigmoid_focal_loss = SigmoidFocalLossFunction.apply
+
+
+class SigmoidFocalLoss(nn.Module):
+
+    def __init__(self, gamma, alpha, weight=None, reduction='mean'):
+        super(SigmoidFocalLoss, self).__init__()
+        self.gamma = gamma
+        self.alpha = alpha
+        self.register_buffer('weight', weight)
+        self.reduction = reduction
+
+    def forward(self, input, target):
+        return sigmoid_focal_loss(input, target, self.gamma, self.alpha,
+                                  self.weight, self.reduction)
+
+    def __repr__(self):
+        s = self.__class__.__name__
+        s += f'(gamma={self.gamma}, '
+        s += f'alpha={self.alpha}, '
+        s += f'reduction={self.reduction})'
+        return s
+
+
+class SoftmaxFocalLossFunction(Function):
+
+    @staticmethod
+    def symbolic(g, input, target, gamma, alpha, weight, reduction):
+        return g.op(
+            'mmcv::MMCVSoftmaxFocalLoss',
+            input,
+            target,
+            gamma_f=gamma,
+            alpha_f=alpha,
+            weight_f=weight,
+            reduction_s=reduction)
+
+    @staticmethod
+    def forward(ctx,
+                input,
+                target,
+                gamma=2.0,
+                alpha=0.25,
+                weight=None,
+                reduction='mean'):
+
+        assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
+        assert input.dim() == 2
+        assert target.dim() == 1
+        assert input.size(0) == target.size(0)
+        if weight is None:
+            weight = input.new_empty(0)
+        else:
+            assert weight.dim() == 1
+            assert input.size(1) == weight.size(0)
+        ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2}
+        assert reduction in ctx.reduction_dict.keys()
+
+        ctx.gamma = float(gamma)
+        ctx.alpha = float(alpha)
+        ctx.reduction = ctx.reduction_dict[reduction]
+
+        channel_stats, _ = torch.max(input, dim=1)
+        input_softmax = input - channel_stats.unsqueeze(1).expand_as(input)
+        input_softmax.exp_()
+
+        channel_stats = input_softmax.sum(dim=1)
+        input_softmax /= channel_stats.unsqueeze(1).expand_as(input)
+
+        output = input.new_zeros(input.size(0))
+        ext_module.softmax_focal_loss_forward(
+            input_softmax,
+            target,
+            weight,
+            output,
+            gamma=ctx.gamma,
+            alpha=ctx.alpha)
+
+        if ctx.reduction == ctx.reduction_dict['mean']:
+            output = output.sum() / input.size(0)
+        elif ctx.reduction == ctx.reduction_dict['sum']:
+            output = output.sum()
+        ctx.save_for_backward(input_softmax, target, weight)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_softmax, target, weight = ctx.saved_tensors
+        buff = input_softmax.new_zeros(input_softmax.size(0))
+        grad_input = input_softmax.new_zeros(input_softmax.size())
+
+        ext_module.softmax_focal_loss_backward(
+            input_softmax,
+            target,
+            weight,
+            buff,
+            grad_input,
+            gamma=ctx.gamma,
+            alpha=ctx.alpha)
+
+        grad_input *= grad_output
+        if ctx.reduction == ctx.reduction_dict['mean']:
+            grad_input /= input_softmax.size(0)
+        return grad_input, None, None, None, None, None
+
+
+softmax_focal_loss = SoftmaxFocalLossFunction.apply
+
+
+class SoftmaxFocalLoss(nn.Module):
+
+    def __init__(self, gamma, alpha, weight=None, reduction='mean'):
+        super(SoftmaxFocalLoss, self).__init__()
+        self.gamma = gamma
+        self.alpha = alpha
+        self.register_buffer('weight', weight)
+        self.reduction = reduction
+
+    def forward(self, input, target):
+        return softmax_focal_loss(input, target, self.gamma, self.alpha,
+                                  self.weight, self.reduction)
+
+    def __repr__(self):
+        s = self.__class__.__name__
+        s += f'(gamma={self.gamma}, '
+        s += f'alpha={self.alpha}, '
+        s += f'reduction={self.reduction})'
+        return s
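A minimal usage sketch for SigmoidFocalLoss (illustration only). Targets are class indices of dtype long, as required by the asserts in the forward pass; a CUDA device, the compiled `_ext` extension, and the export path are assumptions.

import torch

from annotator.uniformer.mmcv.ops import SigmoidFocalLoss  # assumed export path


def focal_loss_demo():
    criterion = SigmoidFocalLoss(gamma=2.0, alpha=0.25).cuda()
    logits = torch.randn(8, 80, device='cuda')           # (num_samples, num_classes)
    labels = torch.randint(0, 80, (8,), device='cuda')   # long class indices
    loss = criterion(logits, labels)                      # scalar ('mean' reduction)
    return loss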
diff --git a/annotator/uniformer/mmcv/ops/furthest_point_sample.py b/annotator/uniformer/mmcv/ops/furthest_point_sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..374b7a878f1972c183941af28ba1df216ac1a60f
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/furthest_point_sample.py
@@ -0,0 +1,83 @@
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+    'furthest_point_sampling_forward',
+    'furthest_point_sampling_with_dist_forward'
+])
+
+
+class FurthestPointSampling(Function):
+    """Uses iterative furthest point sampling to select a set of features whose
+    corresponding points have the furthest distance."""
+
+    @staticmethod
+    def forward(ctx, points_xyz: torch.Tensor,
+                num_points: int) -> torch.Tensor:
+        """
+        Args:
+            points_xyz (Tensor): (B, N, 3) where N > num_points.
+            num_points (int): Number of points in the sampled set.
+
+        Returns:
+             Tensor: (B, num_points) indices of the sampled points.
+        """
+        assert points_xyz.is_contiguous()
+
+        B, N = points_xyz.size()[:2]
+        output = torch.cuda.IntTensor(B, num_points)
+        temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
+
+        ext_module.furthest_point_sampling_forward(
+            points_xyz,
+            temp,
+            output,
+            b=B,
+            n=N,
+            m=num_points,
+        )
+        if torch.__version__ != 'parrots':
+            ctx.mark_non_differentiable(output)
+        return output
+
+    @staticmethod
+    def backward(xyz, a=None):
+        return None, None
+
+
+class FurthestPointSamplingWithDist(Function):
+    """Uses iterative furthest point sampling to select a set of features whose
+    corresponding points have the furthest distance."""
+
+    @staticmethod
+    def forward(ctx, points_dist: torch.Tensor,
+                num_points: int) -> torch.Tensor:
+        """
+        Args:
+            points_dist (Tensor): (B, N, N) Distance between each point pair.
+            num_points (int): Number of points in the sampled set.
+
+        Returns:
+             Tensor: (B, num_points) indices of the sampled points.
+        """
+        assert points_dist.is_contiguous()
+
+        B, N, _ = points_dist.size()
+        output = points_dist.new_zeros([B, num_points], dtype=torch.int32)
+        temp = points_dist.new_zeros([B, N]).fill_(1e10)
+
+        ext_module.furthest_point_sampling_with_dist_forward(
+            points_dist, temp, output, b=B, n=N, m=num_points)
+        if torch.__version__ != 'parrots':
+            ctx.mark_non_differentiable(output)
+        return output
+
+    @staticmethod
+    def backward(xyz, a=None):
+        return None, None
+
+
+furthest_point_sample = FurthestPointSampling.apply
+furthest_point_sample_with_dist = FurthestPointSamplingWithDist.apply
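A minimal usage sketch for furthest_point_sample (illustration only). The implementation above allocates `torch.cuda.IntTensor` buffers, so a GPU and the compiled `_ext` extension are assumed, as is the export path.

import torch

from annotator.uniformer.mmcv.ops import furthest_point_sample  # assumed export path


def fps_demo():
    # Pick 256 well-spread points out of 1024.
    points = torch.randn(2, 1024, 3, device='cuda').contiguous()
    idx = furthest_point_sample(points, 256)   # (B, 256) int32 indices
    assert idx.shape == (2, 256)
    return idx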
diff --git a/annotator/uniformer/mmcv/ops/fused_bias_leakyrelu.py b/annotator/uniformer/mmcv/ops/fused_bias_leakyrelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d12508469c6c8fa1884debece44c58d158cb6fa
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/fused_bias_leakyrelu.py
@@ -0,0 +1,268 @@
+# modified from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
+
+# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
+# NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator
+# Augmentation (ADA)
+# =======================================================================
+
+# 1. Definitions
+
+# "Licensor" means any person or entity that distributes its Work.
+
+# "Software" means the original work of authorship made available under
+# this License.
+
+# "Work" means the Software and any additions to or derivative works of
+# the Software that are made available under this License.
+
+# The terms "reproduce," "reproduction," "derivative works," and
+# "distribution" have the meaning as provided under U.S. copyright law;
+# provided, however, that for the purposes of this License, derivative
+# works shall not include works that remain separable from, or merely
+# link (or bind by name) to the interfaces of, the Work.
+
+# Works, including the Software, are "made available" under this License
+# by including in or with the Work either (a) a copyright notice
+# referencing the applicability of this License to the Work, or (b) a
+# copy of this License.
+
+# 2. License Grants
+
+#     2.1 Copyright Grant. Subject to the terms and conditions of this
+#     License, each Licensor grants to you a perpetual, worldwide,
+#     non-exclusive, royalty-free, copyright license to reproduce,
+#     prepare derivative works of, publicly display, publicly perform,
+#     sublicense and distribute its Work and any resulting derivative
+#     works in any form.
+
+# 3. Limitations
+
+#     3.1 Redistribution. You may reproduce or distribute the Work only
+#     if (a) you do so under this License, (b) you include a complete
+#     copy of this License with your distribution, and (c) you retain
+#     without modification any copyright, patent, trademark, or
+#     attribution notices that are present in the Work.
+
+#     3.2 Derivative Works. You may specify that additional or different
+#     terms apply to the use, reproduction, and distribution of your
+#     derivative works of the Work ("Your Terms") only if (a) Your Terms
+#     provide that the use limitation in Section 3.3 applies to your
+#     derivative works, and (b) you identify the specific derivative
+#     works that are subject to Your Terms. Notwithstanding Your Terms,
+#     this License (including the redistribution requirements in Section
+#     3.1) will continue to apply to the Work itself.
+
+#     3.3 Use Limitation. The Work and any derivative works thereof only
+#     may be used or intended for use non-commercially. Notwithstanding
+#     the foregoing, NVIDIA and its affiliates may use the Work and any
+#     derivative works commercially. As used herein, "non-commercially"
+#     means for research or evaluation purposes only.
+
+#     3.4 Patent Claims. If you bring or threaten to bring a patent claim
+#     against any Licensor (including any claim, cross-claim or
+#     counterclaim in a lawsuit) to enforce any patents that you allege
+#     are infringed by any Work, then your rights under this License from
+#     such Licensor (including the grant in Section 2.1) will terminate
+#     immediately.
+
+#     3.5 Trademarks. This License does not grant any rights to use any
+#     Licensor’s or its affiliates’ names, logos, or trademarks, except
+#     as necessary to reproduce the notices described in this License.
+
+#     3.6 Termination. If you violate any term of this License, then your
+#     rights under this License (including the grant in Section 2.1) will
+#     terminate immediately.
+
+# 4. Disclaimer of Warranty.
+
+# THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
+# NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
+# THIS LICENSE.
+
+# 5. Limitation of Liability.
+
+# EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
+# THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
+# SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
+# INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
+# OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
+# (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
+# LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
+# COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGES.
+
+# =======================================================================
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['fused_bias_leakyrelu'])
+
+
+class FusedBiasLeakyReLUFunctionBackward(Function):
+    """Calculate second order deviation.
+
+    This function is to compute the second order deviation for the fused leaky
+    relu operation.
+    """
+
+    @staticmethod
+    def forward(ctx, grad_output, out, negative_slope, scale):
+        ctx.save_for_backward(out)
+        ctx.negative_slope = negative_slope
+        ctx.scale = scale
+
+        empty = grad_output.new_empty(0)
+
+        grad_input = ext_module.fused_bias_leakyrelu(
+            grad_output,
+            empty,
+            out,
+            act=3,
+            grad=1,
+            alpha=negative_slope,
+            scale=scale)
+
+        dim = [0]
+
+        if grad_input.ndim > 2:
+            dim += list(range(2, grad_input.ndim))
+
+        grad_bias = grad_input.sum(dim).detach()
+
+        return grad_input, grad_bias
+
+    @staticmethod
+    def backward(ctx, gradgrad_input, gradgrad_bias):
+        out, = ctx.saved_tensors
+
+        # The second-order gradient consists of two parts; the first part is
+        # zero, so we only consider the second part, whose implementation
+        # mirrors that of the first-order gradient.
+        gradgrad_out = ext_module.fused_bias_leakyrelu(
+            gradgrad_input,
+            gradgrad_bias.to(out.dtype),
+            out,
+            act=3,
+            grad=1,
+            alpha=ctx.negative_slope,
+            scale=ctx.scale)
+
+        return gradgrad_out, None, None, None
+
+
+class FusedBiasLeakyReLUFunction(Function):
+
+    @staticmethod
+    def forward(ctx, input, bias, negative_slope, scale):
+        empty = input.new_empty(0)
+
+        out = ext_module.fused_bias_leakyrelu(
+            input,
+            bias,
+            empty,
+            act=3,
+            grad=0,
+            alpha=negative_slope,
+            scale=scale)
+        ctx.save_for_backward(out)
+        ctx.negative_slope = negative_slope
+        ctx.scale = scale
+
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        out, = ctx.saved_tensors
+
+        grad_input, grad_bias = FusedBiasLeakyReLUFunctionBackward.apply(
+            grad_output, out, ctx.negative_slope, ctx.scale)
+
+        return grad_input, grad_bias, None, None
+
+
+class FusedBiasLeakyReLU(nn.Module):
+    """Fused bias leaky ReLU.
+
+    This function is introduced in the StyleGAN2:
+    http://arxiv.org/abs/1912.04958
+
+    The bias term comes from the convolution operation. In addition, to keep
+    the variance of the feature map or gradients unchanged, they also adopt a
+    scale similarly with Kaiming initialization. However, since the
+    :math:`1+{alpha}^2` : is too small, we can just ignore it. Therefore, the
+    final scale is just :math:`\sqrt{2}`:. Of course, you may change it with # noqa: W605, E501
+    your own scale.
+
+    TODO: Implement the CPU version.
+
+    Args:
+        channel (int): The channel number of the feature map.
+        negative_slope (float, optional): Same as nn.LeakyReLU.
+            Defaults to 0.2.
+        scale (float, optional): A scalar to adjust the variance of the feature
+            map. Defaults to 2**0.5.
+    """
+
+    def __init__(self, num_channels, negative_slope=0.2, scale=2**0.5):
+        super(FusedBiasLeakyReLU, self).__init__()
+
+        self.bias = nn.Parameter(torch.zeros(num_channels))
+        self.negative_slope = negative_slope
+        self.scale = scale
+
+    def forward(self, input):
+        return fused_bias_leakyrelu(input, self.bias, self.negative_slope,
+                                    self.scale)
+
+
+def fused_bias_leakyrelu(input, bias, negative_slope=0.2, scale=2**0.5):
+    """Fused bias leaky ReLU function.
+
+    This function is introduced in the StyleGAN2:
+    http://arxiv.org/abs/1912.04958
+
+    The bias term comes from the convolution operation. In addition, to keep
+    the variance of the feature map or gradients unchanged, they also adopt a
+    scale similarly with Kaiming initialization. However, since the
+    :math:`1+{alpha}^2` : is too small, we can just ignore it. Therefore, the
+    final scale is just :math:`\sqrt{2}`:. Of course, you may change it with # noqa: W605, E501
+    your own scale.
+
+    Args:
+        input (torch.Tensor): Input feature map.
+        bias (nn.Parameter): The bias from convolution operation.
+        negative_slope (float, optional): Same as nn.LeakyReLU.
+            Defaults to 0.2.
+        scale (float, optional): A scalar to adjust the variance of the feature
+            map. Defaults to 2**0.5.
+
+    Returns:
+        torch.Tensor: Feature map after non-linear activation.
+    """
+
+    if not input.is_cuda:
+        return bias_leakyrelu_ref(input, bias, negative_slope, scale)
+
+    return FusedBiasLeakyReLUFunction.apply(input, bias.to(input.dtype),
+                                            negative_slope, scale)
+
+
+def bias_leakyrelu_ref(x, bias, negative_slope=0.2, scale=2**0.5):
+
+    if bias is not None:
+        assert bias.ndim == 1
+        assert bias.shape[0] == x.shape[1]
+        x = x + bias.reshape([-1 if i == 1 else 1 for i in range(x.ndim)])
+
+    x = F.leaky_relu(x, negative_slope)
+    if scale != 1:
+        x = x * scale
+
+    return x
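A minimal sketch for the fused bias leaky ReLU (illustration only). On CPU the public function falls back to `bias_leakyrelu_ref`, so this snippet runs without the compiled extension; on a CUDA tensor it would dispatch to the fused kernel instead. The export path is an assumption.

import torch

from annotator.uniformer.mmcv.ops import FusedBiasLeakyReLU  # assumed export path


def fused_act_demo():
    # Adds a learnable per-channel bias, applies leaky ReLU, then rescales by
    # `scale` (sqrt(2) by default) to keep the activation variance stable.
    act = FusedBiasLeakyReLU(num_channels=64)
    x = torch.randn(1, 64, 32, 32)
    y = act(x)
    assert y.shape == x.shape
    return y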
diff --git a/annotator/uniformer/mmcv/ops/gather_points.py b/annotator/uniformer/mmcv/ops/gather_points.py
new file mode 100644
index 0000000000000000000000000000000000000000..f52f1677d8ea0facafc56a3672d37adb44677ff3
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/gather_points.py
@@ -0,0 +1,57 @@
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+    '_ext', ['gather_points_forward', 'gather_points_backward'])
+
+
+class GatherPoints(Function):
+    """Gather points with given index."""
+
+    @staticmethod
+    def forward(ctx, features: torch.Tensor,
+                indices: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            features (Tensor): (B, C, N) features to gather.
+            indices (Tensor): (B, M) where M is the number of points.
+
+        Returns:
+            Tensor: (B, C, M) where M is the number of points.
+        """
+        assert features.is_contiguous()
+        assert indices.is_contiguous()
+
+        B, npoint = indices.size()
+        _, C, N = features.size()
+        output = torch.cuda.FloatTensor(B, C, npoint)
+
+        ext_module.gather_points_forward(
+            features, indices, output, b=B, c=C, n=N, npoints=npoint)
+
+        ctx.for_backwards = (indices, C, N)
+        if torch.__version__ != 'parrots':
+            ctx.mark_non_differentiable(indices)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        idx, C, N = ctx.for_backwards
+        B, npoint = idx.size()
+
+        grad_features = torch.cuda.FloatTensor(B, C, N).zero_()
+        grad_out_data = grad_out.data.contiguous()
+        ext_module.gather_points_backward(
+            grad_out_data,
+            idx,
+            grad_features.data,
+            b=B,
+            c=C,
+            n=N,
+            npoints=npoint)
+        return grad_features, None
+
+
+gather_points = GatherPoints.apply
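A minimal usage sketch combining gather_points with furthest_point_sample (illustration only). Both ops allocate CUDA tensors internally, so a GPU and the compiled `_ext` extension are assumed; the indices are the int32 output of the sampler, and the export paths are assumptions.

import torch

from annotator.uniformer.mmcv.ops import (  # assumed export paths
    furthest_point_sample, gather_points)


def gather_demo():
    xyz = torch.randn(2, 1024, 3, device='cuda').contiguous()
    feats = torch.randn(2, 32, 1024, device='cuda').contiguous()
    idx = furthest_point_sample(xyz, 128)   # (2, 128) int32
    sampled = gather_points(feats, idx)     # (2, 32, 128)
    assert sampled.shape == (2, 32, 128)
    return sampled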
diff --git a/annotator/uniformer/mmcv/ops/group_points.py b/annotator/uniformer/mmcv/ops/group_points.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c3ec9d758ebe4e1c2205882af4be154008253a5
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/group_points.py
@@ -0,0 +1,224 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
+import torch
+from torch import nn as nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+from .ball_query import ball_query
+from .knn import knn
+
+ext_module = ext_loader.load_ext(
+    '_ext', ['group_points_forward', 'group_points_backward'])
+
+
+class QueryAndGroup(nn.Module):
+    """Groups points with a ball query of radius.
+
+    Args:
+        max_radius (float): The maximum radius of the balls.
+            If None is given, we will use kNN sampling instead of ball query.
+        sample_num (int): Maximum number of features to gather in the ball.
+        min_radius (float, optional): The minimum radius of the balls.
+            Default: 0.
+        use_xyz (bool, optional): Whether to use xyz.
+            Default: True.
+        return_grouped_xyz (bool, optional): Whether to return grouped xyz.
+            Default: False.
+        normalize_xyz (bool, optional): Whether to normalize xyz.
+            Default: False.
+        uniform_sample (bool, optional): Whether to sample uniformly.
+            Default: False
+        return_unique_cnt (bool, optional): Whether to return the count of
+            unique samples. Default: False.
+        return_grouped_idx (bool, optional): Whether to return grouped idx.
+            Default: False.
+    """
+
+    def __init__(self,
+                 max_radius,
+                 sample_num,
+                 min_radius=0,
+                 use_xyz=True,
+                 return_grouped_xyz=False,
+                 normalize_xyz=False,
+                 uniform_sample=False,
+                 return_unique_cnt=False,
+                 return_grouped_idx=False):
+        super().__init__()
+        self.max_radius = max_radius
+        self.min_radius = min_radius
+        self.sample_num = sample_num
+        self.use_xyz = use_xyz
+        self.return_grouped_xyz = return_grouped_xyz
+        self.normalize_xyz = normalize_xyz
+        self.uniform_sample = uniform_sample
+        self.return_unique_cnt = return_unique_cnt
+        self.return_grouped_idx = return_grouped_idx
+        if self.return_unique_cnt:
+            assert self.uniform_sample, \
+                'uniform_sample should be True when ' \
+                'returning the count of unique samples'
+        if self.max_radius is None:
+            assert not self.normalize_xyz, \
+                'can not normalize grouped xyz when max_radius is None'
+
+    def forward(self, points_xyz, center_xyz, features=None):
+        """
+        Args:
+            points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
+            center_xyz (Tensor): (B, npoint, 3) coordinates of the centroids.
+            features (Tensor): (B, C, N) Descriptors of the features.
+
+        Returns:
+            Tensor: (B, 3 + C, npoint, sample_num) Grouped feature.
+        """
+        # if self.max_radius is None, we will perform kNN instead of ball query
+        # idx is of shape [B, npoint, sample_num]
+        if self.max_radius is None:
+            idx = knn(self.sample_num, points_xyz, center_xyz, False)
+            idx = idx.transpose(1, 2).contiguous()
+        else:
+            idx = ball_query(self.min_radius, self.max_radius, self.sample_num,
+                             points_xyz, center_xyz)
+
+        if self.uniform_sample:
+            unique_cnt = torch.zeros((idx.shape[0], idx.shape[1]))
+            for i_batch in range(idx.shape[0]):
+                for i_region in range(idx.shape[1]):
+                    unique_ind = torch.unique(idx[i_batch, i_region, :])
+                    num_unique = unique_ind.shape[0]
+                    unique_cnt[i_batch, i_region] = num_unique
+                    sample_ind = torch.randint(
+                        0,
+                        num_unique, (self.sample_num - num_unique, ),
+                        dtype=torch.long)
+                    all_ind = torch.cat((unique_ind, unique_ind[sample_ind]))
+                    idx[i_batch, i_region, :] = all_ind
+
+        xyz_trans = points_xyz.transpose(1, 2).contiguous()
+        # (B, 3, npoint, sample_num)
+        grouped_xyz = grouping_operation(xyz_trans, idx)
+        grouped_xyz_diff = grouped_xyz - \
+            center_xyz.transpose(1, 2).unsqueeze(-1)  # relative offsets
+        if self.normalize_xyz:
+            grouped_xyz_diff /= self.max_radius
+
+        if features is not None:
+            grouped_features = grouping_operation(features, idx)
+            if self.use_xyz:
+                # (B, C + 3, npoint, sample_num)
+                new_features = torch.cat([grouped_xyz_diff, grouped_features],
+                                         dim=1)
+            else:
+                new_features = grouped_features
+        else:
+            assert self.use_xyz, \
+                'use_xyz must be True when features is None'
+            new_features = grouped_xyz_diff
+
+        ret = [new_features]
+        if self.return_grouped_xyz:
+            ret.append(grouped_xyz)
+        if self.return_unique_cnt:
+            ret.append(unique_cnt)
+        if self.return_grouped_idx:
+            ret.append(idx)
+        if len(ret) == 1:
+            return ret[0]
+        else:
+            return tuple(ret)
+
+
+class GroupAll(nn.Module):
+    """Group xyz with feature.
+
+    Args:
+        use_xyz (bool): Whether to use xyz.
+    """
+
+    def __init__(self, use_xyz: bool = True):
+        super().__init__()
+        self.use_xyz = use_xyz
+
+    def forward(self,
+                xyz: torch.Tensor,
+                new_xyz: torch.Tensor,
+                features: torch.Tensor = None):
+        """
+        Args:
+            xyz (Tensor): (B, N, 3) xyz coordinates of the features.
+            new_xyz (Tensor): New xyz coordinates of the features
+                (unused here, kept for interface compatibility).
+            features (Tensor): (B, C, N) features to group.
+
+        Returns:
+            Tensor: (B, C + 3, 1, N) Grouped feature.
+        """
+        grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
+        if features is not None:
+            grouped_features = features.unsqueeze(2)
+            if self.use_xyz:
+                # (B, 3 + C, 1, N)
+                new_features = torch.cat([grouped_xyz, grouped_features],
+                                         dim=1)
+            else:
+                new_features = grouped_features
+        else:
+            new_features = grouped_xyz
+
+        return new_features
+
+
+class GroupingOperation(Function):
+    """Group feature with given index."""
+
+    @staticmethod
+    def forward(ctx, features: torch.Tensor,
+                indices: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            features (Tensor): (B, C, N) tensor of features to group.
+            indices (Tensor): (B, npoint, nsample) the indices of
+                features to group with.
+
+        Returns:
+            Tensor: (B, C, npoint, nsample) Grouped features.
+        """
+        features = features.contiguous()
+        indices = indices.contiguous()
+
+        B, nfeatures, nsample = indices.size()
+        _, C, N = features.size()
+        output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
+
+        ext_module.group_points_forward(B, C, N, nfeatures, nsample, features,
+                                        indices, output)
+
+        ctx.for_backwards = (indices, N)
+        return output
+
+    @staticmethod
+    def backward(ctx,
+                 grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            grad_out (Tensor): (B, C, npoint, nsample) tensor of the gradients
+                of the output from forward.
+
+        Returns:
+            Tensor: (B, C, N) gradient of the features.
+        """
+        idx, N = ctx.for_backwards
+
+        B, C, npoint, nsample = grad_out.size()
+        grad_features = torch.cuda.FloatTensor(B, C, N).zero_()
+
+        grad_out_data = grad_out.data.contiguous()
+        ext_module.group_points_backward(B, C, N, npoint, nsample,
+                                         grad_out_data, idx,
+                                         grad_features.data)
+        return grad_features, None
+
+
+grouping_operation = GroupingOperation.apply
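The `grouping_operation` above dispatches to the compiled `_ext` CUDA kernel. As a way to see what it computes, here is a pure-PyTorch reference sketch (not the library code) that reproduces the same gather semantics with `torch.gather`; the function name is illustrative.

import torch

def grouping_reference(features: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # features: (B, C, N), indices: (B, npoint, nsample) integer tensor
    B, C, N = features.shape
    _, npoint, nsample = indices.shape
    # broadcast the indices over the channel dimension so torch.gather can
    # pick along the last (point) dimension of features
    idx = indices.view(B, 1, npoint * nsample).expand(B, C, -1).long()
    grouped = torch.gather(features, 2, idx)      # (B, C, npoint*nsample)
    return grouped.view(B, C, npoint, nsample)    # (B, C, npoint, nsample)

# quick shape check with random data
feats = torch.randn(2, 16, 1024)
idx = torch.randint(0, 1024, (2, 128, 32))
assert grouping_reference(feats, idx).shape == (2, 16, 128, 32)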
diff --git a/annotator/uniformer/mmcv/ops/info.py b/annotator/uniformer/mmcv/ops/info.py
new file mode 100644
index 0000000000000000000000000000000000000000..29f2e5598ae2bb5866ccd15a7d3b4de33c0cd14d
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/info.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import glob
+import os
+
+import torch
+
+if torch.__version__ == 'parrots':
+    import parrots
+
+    def get_compiler_version():
+        return 'GCC ' + parrots.version.compiler
+
+    def get_compiling_cuda_version():
+        return parrots.version.cuda
+else:
+    from ..utils import ext_loader
+    ext_module = ext_loader.load_ext(
+        '_ext', ['get_compiler_version', 'get_compiling_cuda_version'])
+
+    def get_compiler_version():
+        return ext_module.get_compiler_version()
+
+    def get_compiling_cuda_version():
+        return ext_module.get_compiling_cuda_version()
+
+
+def get_onnxruntime_op_path():
+    wildcard = os.path.join(
+        os.path.abspath(os.path.dirname(os.path.dirname(__file__))),
+        '_ext_ort.*.so')
+
+    paths = glob.glob(wildcard)
+    if len(paths) > 0:
+        return paths[0]
+    else:
+        return ''
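A minimal usage sketch for the helpers above, assuming the compiled `_ext` module of this vendored package is importable; the printed values are illustrative.

from annotator.uniformer.mmcv.ops.info import (get_compiler_version,
                                               get_compiling_cuda_version,
                                               get_onnxruntime_op_path)

print(get_compiler_version())        # e.g. 'GCC 7.3'
print(get_compiling_cuda_version())  # e.g. '11.1'

# returns '' when the custom ONNX Runtime ops were not built
ort_lib = get_onnxruntime_op_path()
print(ort_lib or 'no _ext_ort library found')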
diff --git a/annotator/uniformer/mmcv/ops/iou3d.py b/annotator/uniformer/mmcv/ops/iou3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fc71979190323f44c09f8b7e1761cf49cd2d76b
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/iou3d.py
@@ -0,0 +1,85 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+    'iou3d_boxes_iou_bev_forward', 'iou3d_nms_forward',
+    'iou3d_nms_normal_forward'
+])
+
+
+def boxes_iou_bev(boxes_a, boxes_b):
+    """Calculate boxes IoU in the Bird's Eye View.
+
+    Args:
+        boxes_a (torch.Tensor): Input boxes a with shape (M, 5).
+        boxes_b (torch.Tensor): Input boxes b with shape (N, 5).
+
+    Returns:
+        ans_iou (torch.Tensor): IoU result with shape (M, N).
+    """
+    ans_iou = boxes_a.new_zeros(
+        torch.Size((boxes_a.shape[0], boxes_b.shape[0])))
+
+    ext_module.iou3d_boxes_iou_bev_forward(boxes_a.contiguous(),
+                                           boxes_b.contiguous(), ans_iou)
+
+    return ans_iou
+
+
+def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None):
+    """NMS function GPU implementation (for BEV boxes). The overlap of two
+    boxes for IoU calculation is defined as the exact overlapping area of the
+    two boxes. In this function, one can also set ``pre_max_size`` and
+    ``post_max_size``.
+
+    Args:
+        boxes (torch.Tensor): Input boxes with the shape of [N, 5]
+            ([x1, y1, x2, y2, ry]).
+        scores (torch.Tensor): Scores of boxes with the shape of [N].
+        thresh (float): Overlap threshold of NMS.
+        pre_max_size (int, optional): Max size of boxes before NMS.
+            Default: None.
+        post_max_size (int, optional): Max size of boxes after NMS.
+            Default: None.
+
+    Returns:
+        torch.Tensor: Indexes after NMS.
+    """
+    assert boxes.size(1) == 5, 'Input boxes shape should be [N, 5]'
+    order = scores.sort(0, descending=True)[1]
+
+    if pre_max_size is not None:
+        order = order[:pre_max_size]
+    boxes = boxes[order].contiguous()
+
+    keep = torch.zeros(boxes.size(0), dtype=torch.long)
+    num_out = ext_module.iou3d_nms_forward(boxes, keep, thresh)
+    keep = order[keep[:num_out].cuda(boxes.device)].contiguous()
+    if post_max_size is not None:
+        keep = keep[:post_max_size]
+    return keep
+
+
+def nms_normal_bev(boxes, scores, thresh):
+    """Normal NMS function GPU implementation (for BEV boxes). The overlap of
+    two boxes for IoU calculation is defined as the exact overlapping area of
+    the two boxes WITH their yaw angle set to 0.
+
+    Args:
+        boxes (torch.Tensor): Input boxes with shape (N, 5).
+        scores (torch.Tensor): Scores of predicted boxes with shape (N).
+        thresh (float): Overlap threshold of NMS.
+
+    Returns:
+        torch.Tensor: Remaining indices with scores in descending order.
+    """
+    assert boxes.shape[1] == 5, 'Input boxes shape should be [N, 5]'
+    order = scores.sort(0, descending=True)[1]
+
+    boxes = boxes[order].contiguous()
+
+    keep = torch.zeros(boxes.size(0), dtype=torch.long)
+    num_out = ext_module.iou3d_nms_normal_forward(boxes, keep, thresh)
+    return order[keep[:num_out].cuda(boxes.device)].contiguous()
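A usage sketch for the BEV helpers above, assuming a CUDA build of the `_ext` extension and CUDA input tensors; the box values are arbitrary.

import torch
from annotator.uniformer.mmcv.ops.iou3d import boxes_iou_bev, nms_bev

# each box is (x1, y1, x2, y2, ry); the second box heavily overlaps the first
boxes = torch.tensor([[0.0, 0.0, 2.0, 2.0, 0.0],
                      [0.1, 0.1, 2.1, 2.1, 0.0],
                      [5.0, 5.0, 7.0, 7.0, 0.0]], device='cuda')
scores = torch.tensor([0.9, 0.8, 0.7], device='cuda')

iou = boxes_iou_bev(boxes, boxes)          # (3, 3) pairwise BEV IoU matrix
keep = nms_bev(boxes, scores, thresh=0.5)  # the overlapping pair collapses to one index
print(iou.shape, keep)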
diff --git a/annotator/uniformer/mmcv/ops/knn.py b/annotator/uniformer/mmcv/ops/knn.py
new file mode 100644
index 0000000000000000000000000000000000000000..f335785036669fc19239825b0aae6dde3f73bf92
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/knn.py
@@ -0,0 +1,77 @@
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['knn_forward'])
+
+
+class KNN(Function):
+    r"""KNN (CUDA) based on heap data structure.
+    Modified from `PAConv <https://github.com/CVMI-Lab/PAConv/tree/main/
+    scene_seg/lib/pointops/src/knnquery_heap>`_.
+
+    Find k-nearest points.
+    """
+
+    @staticmethod
+    def forward(ctx,
+                k: int,
+                xyz: torch.Tensor,
+                center_xyz: torch.Tensor = None,
+                transposed: bool = False) -> torch.Tensor:
+        """
+        Args:
+            k (int): number of nearest neighbors.
+            xyz (Tensor): (B, N, 3) if transposed == False, else (B, 3, N).
+                xyz coordinates of the features.
+            center_xyz (Tensor, optional): (B, npoint, 3) if transposed ==
+                False, else (B, 3, npoint). centers of the knn query.
+                Default: None.
+            transposed (bool, optional): whether the input tensors are
+                transposed. Do not pass this as a keyword argument when
+                calling knn (= KNN.apply); pass it as the fourth positional
+                argument. Default: False.
+
+        Returns:
+            Tensor: (B, k, npoint) tensor with the indices of
+                the features that form k-nearest neighbours.
+        """
+        assert (k > 0) & (k < 100), 'k should be in range(0, 100)'
+
+        if center_xyz is None:
+            center_xyz = xyz
+
+        if transposed:
+            xyz = xyz.transpose(2, 1).contiguous()
+            center_xyz = center_xyz.transpose(2, 1).contiguous()
+
+        assert xyz.is_contiguous()  # [B, N, 3]
+        assert center_xyz.is_contiguous()  # [B, npoint, 3]
+
+        center_xyz_device = center_xyz.get_device()
+        assert center_xyz_device == xyz.get_device(), \
+            'center_xyz and xyz should be put on the same device'
+        if torch.cuda.current_device() != center_xyz_device:
+            torch.cuda.set_device(center_xyz_device)
+
+        B, npoint, _ = center_xyz.shape
+        N = xyz.shape[1]
+
+        idx = center_xyz.new_zeros((B, npoint, k)).int()
+        dist2 = center_xyz.new_zeros((B, npoint, k)).float()
+
+        ext_module.knn_forward(
+            xyz, center_xyz, idx, dist2, b=B, n=N, m=npoint, nsample=k)
+        # idx shape to [B, k, npoint]
+        idx = idx.transpose(2, 1).contiguous()
+        if torch.__version__ != 'parrots':
+            ctx.mark_non_differentiable(idx)
+        return idx
+
+    @staticmethod
+    def backward(ctx, a=None):
+        return None, None, None
+
+
+knn = KNN.apply
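The `knn_forward` kernel is only reachable through the compiled extension. The following pure-PyTorch sketch (not the library code) mirrors the same interface with `torch.cdist` and `topk`, which is handy for sanity-checking shapes on CPU.

import torch

def knn_reference(k: int, xyz: torch.Tensor, center_xyz: torch.Tensor) -> torch.Tensor:
    # xyz: (B, N, 3), center_xyz: (B, npoint, 3)
    # pairwise distances between every query center and every point: (B, npoint, N)
    dist = torch.cdist(center_xyz, xyz)
    # indices of the k closest points per center: (B, npoint, k)
    idx = dist.topk(k, dim=-1, largest=False).indices
    return idx.transpose(1, 2).contiguous()  # (B, k, npoint), matching KNN.forward

xyz = torch.randn(2, 256, 3)
centers = xyz[:, :32, :]
assert knn_reference(16, xyz, centers).shape == (2, 16, 32)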
diff --git a/annotator/uniformer/mmcv/ops/masked_conv.py b/annotator/uniformer/mmcv/ops/masked_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd514cc204c1d571ea5dc7e74b038c0f477a008b
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/masked_conv.py
@@ -0,0 +1,111 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+    '_ext', ['masked_im2col_forward', 'masked_col2im_forward'])
+
+
+class MaskedConv2dFunction(Function):
+
+    @staticmethod
+    def symbolic(g, features, mask, weight, bias, padding, stride):
+        return g.op(
+            'mmcv::MMCVMaskedConv2d',
+            features,
+            mask,
+            weight,
+            bias,
+            padding_i=padding,
+            stride_i=stride)
+
+    @staticmethod
+    def forward(ctx, features, mask, weight, bias, padding=0, stride=1):
+        assert mask.dim() == 3 and mask.size(0) == 1
+        assert features.dim() == 4 and features.size(0) == 1
+        assert features.size()[2:] == mask.size()[1:]
+        pad_h, pad_w = _pair(padding)
+        stride_h, stride_w = _pair(stride)
+        if stride_h != 1 or stride_w != 1:
+            raise ValueError(
+                'Only stride 1 is supported in masked_conv2d currently.')
+        out_channel, in_channel, kernel_h, kernel_w = weight.size()
+
+        batch_size = features.size(0)
+        out_h = int(
+            math.floor((features.size(2) + 2 * pad_h -
+                        (kernel_h - 1) - 1) / stride_h + 1))
+        out_w = int(
+            math.floor((features.size(3) + 2 * pad_w -
+                        (kernel_w - 1) - 1) / stride_w + 1))
+        mask_inds = torch.nonzero(mask[0] > 0, as_tuple=False)
+        output = features.new_zeros(batch_size, out_channel, out_h, out_w)
+        if mask_inds.numel() > 0:
+            mask_h_idx = mask_inds[:, 0].contiguous()
+            mask_w_idx = mask_inds[:, 1].contiguous()
+            data_col = features.new_zeros(in_channel * kernel_h * kernel_w,
+                                          mask_inds.size(0))
+            ext_module.masked_im2col_forward(
+                features,
+                mask_h_idx,
+                mask_w_idx,
+                data_col,
+                kernel_h=kernel_h,
+                kernel_w=kernel_w,
+                pad_h=pad_h,
+                pad_w=pad_w)
+
+            masked_output = torch.addmm(bias[:, None],
+                                        weight.view(out_channel, -1), data_col)
+            ext_module.masked_col2im_forward(
+                masked_output,
+                mask_h_idx,
+                mask_w_idx,
+                output,
+                height=out_h,
+                width=out_w,
+                channels=out_channel)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        return (None, ) * 5
+
+
+masked_conv2d = MaskedConv2dFunction.apply
+
+
+class MaskedConv2d(nn.Conv2d):
+    """A MaskedConv2d which inherits the official Conv2d.
+
+    The masked forward does not implement a backward function and currently
+    only supports a stride of 1.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 bias=True):
+        super(MaskedConv2d,
+              self).__init__(in_channels, out_channels, kernel_size, stride,
+                             padding, dilation, groups, bias)
+
+    def forward(self, input, mask=None):
+        if mask is None:  # fallback to the normal Conv2d
+            return super(MaskedConv2d, self).forward(input)
+        else:
+            return masked_conv2d(input, mask, self.weight, self.bias,
+                                 self.padding)
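A usage sketch for `MaskedConv2d`, assuming a CUDA build of the `_ext` extension (the masked path calls `masked_im2col_forward`); sizes are arbitrary.

import torch
from annotator.uniformer.mmcv.ops.masked_conv import MaskedConv2d

conv = MaskedConv2d(3, 8, kernel_size=3, padding=1).cuda()
x = torch.randn(1, 3, 32, 32, device='cuda')

dense = conv(x)                   # no mask: behaves like a plain nn.Conv2d

mask = torch.zeros(1, 32, 32, device='cuda')
mask[:, 8:24, 8:24] = 1           # only this region is convolved
sparse = conv(x, mask)
print(dense.shape, sparse.shape)  # both torch.Size([1, 8, 32, 32])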
diff --git a/annotator/uniformer/mmcv/ops/merge_cells.py b/annotator/uniformer/mmcv/ops/merge_cells.py
new file mode 100644
index 0000000000000000000000000000000000000000..48ca8cc0a8aca8432835bd760c0403a3c35b34cf
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/merge_cells.py
@@ -0,0 +1,149 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import abstractmethod
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..cnn import ConvModule
+
+
+class BaseMergeCell(nn.Module):
+    """The basic class for cells used in NAS-FPN and NAS-FCOS.
+
+    BaseMergeCell takes 2 inputs. After applying convolution
+    on them, they are resized to the target size. Then,
+    they go through binary_op, which depends on the type of cell.
+    If with_out_conv is True, the result of output will go through
+    another convolution layer.
+
+    Args:
+        fused_channels (int): number of input channels of the out_conv layer.
+        out_channels (int): number of output channels of the out_conv layer.
+        with_out_conv (bool): Whether to use the out_conv layer.
+        out_conv_cfg (dict): Config dict for convolution layer, which should
+            contain "groups", "kernel_size", "padding", "bias" to build
+            out_conv layer.
+        out_norm_cfg (dict): Config dict for normalization layer in out_conv.
+        out_conv_order (tuple): The order of conv/norm/activation layers in
+            out_conv.
+        with_input1_conv (bool): Whether to use convolution on input1.
+        with_input2_conv (bool): Whether to use convolution on input2.
+        input_conv_cfg (dict): Config dict for building input1_conv layer and
+            input2_conv layer, which is expected to contain the type of
+            convolution.
+            Default: None, which means using conv2d.
+        input_norm_cfg (dict): Config dict for normalization layer in
+            input1_conv and input2_conv layer. Default: None.
+        upsample_mode (str): Interpolation method used to resize the output
+            of input1_conv and input2_conv to target size. Currently, we
+            support ['nearest', 'bilinear']. Default: 'nearest'.
+    """
+
+    def __init__(self,
+                 fused_channels=256,
+                 out_channels=256,
+                 with_out_conv=True,
+                 out_conv_cfg=dict(
+                     groups=1, kernel_size=3, padding=1, bias=True),
+                 out_norm_cfg=None,
+                 out_conv_order=('act', 'conv', 'norm'),
+                 with_input1_conv=False,
+                 with_input2_conv=False,
+                 input_conv_cfg=None,
+                 input_norm_cfg=None,
+                 upsample_mode='nearest'):
+        super(BaseMergeCell, self).__init__()
+        assert upsample_mode in ['nearest', 'bilinear']
+        self.with_out_conv = with_out_conv
+        self.with_input1_conv = with_input1_conv
+        self.with_input2_conv = with_input2_conv
+        self.upsample_mode = upsample_mode
+
+        if self.with_out_conv:
+            self.out_conv = ConvModule(
+                fused_channels,
+                out_channels,
+                **out_conv_cfg,
+                norm_cfg=out_norm_cfg,
+                order=out_conv_order)
+
+        self.input1_conv = self._build_input_conv(
+            out_channels, input_conv_cfg,
+            input_norm_cfg) if with_input1_conv else nn.Sequential()
+        self.input2_conv = self._build_input_conv(
+            out_channels, input_conv_cfg,
+            input_norm_cfg) if with_input2_conv else nn.Sequential()
+
+    def _build_input_conv(self, channel, conv_cfg, norm_cfg):
+        return ConvModule(
+            channel,
+            channel,
+            3,
+            padding=1,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            bias=True)
+
+    @abstractmethod
+    def _binary_op(self, x1, x2):
+        pass
+
+    def _resize(self, x, size):
+        if x.shape[-2:] == size:
+            return x
+        elif x.shape[-2:] < size:
+            return F.interpolate(x, size=size, mode=self.upsample_mode)
+        else:
+            assert x.shape[-2] % size[-2] == 0 and x.shape[-1] % size[-1] == 0
+            kernel_size = x.shape[-1] // size[-1]
+            x = F.max_pool2d(x, kernel_size=kernel_size, stride=kernel_size)
+            return x
+
+    def forward(self, x1, x2, out_size=None):
+        assert x1.shape[:2] == x2.shape[:2]
+        assert out_size is None or len(out_size) == 2
+        if out_size is None:  # resize to larger one
+            out_size = max(x1.size()[2:], x2.size()[2:])
+
+        x1 = self.input1_conv(x1)
+        x2 = self.input2_conv(x2)
+
+        x1 = self._resize(x1, out_size)
+        x2 = self._resize(x2, out_size)
+
+        x = self._binary_op(x1, x2)
+        if self.with_out_conv:
+            x = self.out_conv(x)
+        return x
+
+
+class SumCell(BaseMergeCell):
+
+    def __init__(self, in_channels, out_channels, **kwargs):
+        super(SumCell, self).__init__(in_channels, out_channels, **kwargs)
+
+    def _binary_op(self, x1, x2):
+        return x1 + x2
+
+
+class ConcatCell(BaseMergeCell):
+
+    def __init__(self, in_channels, out_channels, **kwargs):
+        super(ConcatCell, self).__init__(in_channels * 2, out_channels,
+                                         **kwargs)
+
+    def _binary_op(self, x1, x2):
+        ret = torch.cat([x1, x2], dim=1)
+        return ret
+
+
+class GlobalPoolingCell(BaseMergeCell):
+
+    def __init__(self, in_channels=None, out_channels=None, **kwargs):
+        super().__init__(in_channels, out_channels, **kwargs)
+        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
+
+    def _binary_op(self, x1, x2):
+        x2_att = self.global_pool(x2).sigmoid()
+        return x2 + x2_att * x1
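A CPU-friendly sketch of the merge cells above, assuming `ConvModule` from the vendored `..cnn` package builds as usual; the lower-resolution input is resized to the larger size before the binary op.

import torch
from annotator.uniformer.mmcv.ops.merge_cells import SumCell, GlobalPoolingCell

x_hi = torch.randn(1, 64, 32, 32)   # higher-resolution feature map
x_lo = torch.randn(1, 64, 16, 16)   # lower-resolution feature map

cell = SumCell(in_channels=64, out_channels=64)
merged = cell(x_hi, x_lo)           # x_lo is upsampled to 32x32, then summed
print(merged.shape)                 # torch.Size([1, 64, 32, 32])

gp_cell = GlobalPoolingCell(in_channels=64, out_channels=64)
print(gp_cell(x_hi, x_hi).shape)    # x2 + sigmoid(GAP(x2)) * x1, same shape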
diff --git a/annotator/uniformer/mmcv/ops/modulated_deform_conv.py b/annotator/uniformer/mmcv/ops/modulated_deform_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..75559579cf053abcc99538606cbb88c723faf783
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/modulated_deform_conv.py
@@ -0,0 +1,282 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair, _single
+
+from annotator.uniformer.mmcv.utils import deprecated_api_warning
+from ..cnn import CONV_LAYERS
+from ..utils import ext_loader, print_log
+
+ext_module = ext_loader.load_ext(
+    '_ext',
+    ['modulated_deform_conv_forward', 'modulated_deform_conv_backward'])
+
+
+class ModulatedDeformConv2dFunction(Function):
+
+    @staticmethod
+    def symbolic(g, input, offset, mask, weight, bias, stride, padding,
+                 dilation, groups, deform_groups):
+        input_tensors = [input, offset, mask, weight]
+        if bias is not None:
+            input_tensors.append(bias)
+        return g.op(
+            'mmcv::MMCVModulatedDeformConv2d',
+            *input_tensors,
+            stride_i=stride,
+            padding_i=padding,
+            dilation_i=dilation,
+            groups_i=groups,
+            deform_groups_i=deform_groups)
+
+    @staticmethod
+    def forward(ctx,
+                input,
+                offset,
+                mask,
+                weight,
+                bias=None,
+                stride=1,
+                padding=0,
+                dilation=1,
+                groups=1,
+                deform_groups=1):
+        if input is not None and input.dim() != 4:
+            raise ValueError(
+                f'Expected 4D tensor as input, '
+                f'got {input.dim()}D tensor instead.')
+        ctx.stride = _pair(stride)
+        ctx.padding = _pair(padding)
+        ctx.dilation = _pair(dilation)
+        ctx.groups = groups
+        ctx.deform_groups = deform_groups
+        ctx.with_bias = bias is not None
+        if not ctx.with_bias:
+            bias = input.new_empty(0)  # fake tensor
+        # When pytorch version >= 1.6.0, amp is adopted for fp16 mode;
+        # amp won't cast the type of model (float32), but "offset" is cast
+        # to float16 by nn.Conv2d automatically, leading to the type
+        # mismatch with input (when it is float32) or weight.
+        # The flag for whether to use fp16 or amp is the type of "offset",
+        # we cast weight and input to temporarily support fp16 and amp
+        # whatever the pytorch version is.
+        input = input.type_as(offset)
+        weight = weight.type_as(input)
+        ctx.save_for_backward(input, offset, mask, weight, bias)
+        output = input.new_empty(
+            ModulatedDeformConv2dFunction._output_size(ctx, input, weight))
+        ctx._bufs = [input.new_empty(0), input.new_empty(0)]
+        ext_module.modulated_deform_conv_forward(
+            input,
+            weight,
+            bias,
+            ctx._bufs[0],
+            offset,
+            mask,
+            output,
+            ctx._bufs[1],
+            kernel_h=weight.size(2),
+            kernel_w=weight.size(3),
+            stride_h=ctx.stride[0],
+            stride_w=ctx.stride[1],
+            pad_h=ctx.padding[0],
+            pad_w=ctx.padding[1],
+            dilation_h=ctx.dilation[0],
+            dilation_w=ctx.dilation[1],
+            group=ctx.groups,
+            deformable_group=ctx.deform_groups,
+            with_bias=ctx.with_bias)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        input, offset, mask, weight, bias = ctx.saved_tensors
+        grad_input = torch.zeros_like(input)
+        grad_offset = torch.zeros_like(offset)
+        grad_mask = torch.zeros_like(mask)
+        grad_weight = torch.zeros_like(weight)
+        grad_bias = torch.zeros_like(bias)
+        grad_output = grad_output.contiguous()
+        ext_module.modulated_deform_conv_backward(
+            input,
+            weight,
+            bias,
+            ctx._bufs[0],
+            offset,
+            mask,
+            ctx._bufs[1],
+            grad_input,
+            grad_weight,
+            grad_bias,
+            grad_offset,
+            grad_mask,
+            grad_output,
+            kernel_h=weight.size(2),
+            kernel_w=weight.size(3),
+            stride_h=ctx.stride[0],
+            stride_w=ctx.stride[1],
+            pad_h=ctx.padding[0],
+            pad_w=ctx.padding[1],
+            dilation_h=ctx.dilation[0],
+            dilation_w=ctx.dilation[1],
+            group=ctx.groups,
+            deformable_group=ctx.deform_groups,
+            with_bias=ctx.with_bias)
+        if not ctx.with_bias:
+            grad_bias = None
+
+        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
+                None, None, None, None, None)
+
+    @staticmethod
+    def _output_size(ctx, input, weight):
+        channels = weight.size(0)
+        output_size = (input.size(0), channels)
+        for d in range(input.dim() - 2):
+            in_size = input.size(d + 2)
+            pad = ctx.padding[d]
+            kernel = ctx.dilation[d] * (weight.size(d + 2) - 1) + 1
+            stride_ = ctx.stride[d]
+            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
+        if not all(map(lambda s: s > 0, output_size)):
+            raise ValueError(
+                'convolution input is too small (output would be ' +
+                'x'.join(map(str, output_size)) + ')')
+        return output_size
+
+
+modulated_deform_conv2d = ModulatedDeformConv2dFunction.apply
+
+
+class ModulatedDeformConv2d(nn.Module):
+
+    @deprecated_api_warning({'deformable_groups': 'deform_groups'},
+                            cls_name='ModulatedDeformConv2d')
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 deform_groups=1,
+                 bias=True):
+        super(ModulatedDeformConv2d, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = _pair(kernel_size)
+        self.stride = _pair(stride)
+        self.padding = _pair(padding)
+        self.dilation = _pair(dilation)
+        self.groups = groups
+        self.deform_groups = deform_groups
+        # enable compatibility with nn.Conv2d
+        self.transposed = False
+        self.output_padding = _single(0)
+
+        self.weight = nn.Parameter(
+            torch.Tensor(out_channels, in_channels // groups,
+                         *self.kernel_size))
+        if bias:
+            self.bias = nn.Parameter(torch.Tensor(out_channels))
+        else:
+            self.register_parameter('bias', None)
+        self.init_weights()
+
+    def init_weights(self):
+        n = self.in_channels
+        for k in self.kernel_size:
+            n *= k
+        stdv = 1. / math.sqrt(n)
+        self.weight.data.uniform_(-stdv, stdv)
+        if self.bias is not None:
+            self.bias.data.zero_()
+
+    def forward(self, x, offset, mask):
+        return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
+                                       self.stride, self.padding,
+                                       self.dilation, self.groups,
+                                       self.deform_groups)
+
+
+@CONV_LAYERS.register_module('DCNv2')
+class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
+    """A ModulatedDeformable Conv Encapsulation that acts as normal Conv
+    layers.
+
+    Args:
+        in_channels (int): Same as nn.Conv2d.
+        out_channels (int): Same as nn.Conv2d.
+        kernel_size (int or tuple[int]): Same as nn.Conv2d.
+        stride (int): Same as nn.Conv2d, while tuple is not supported.
+        padding (int): Same as nn.Conv2d, while tuple is not supported.
+        dilation (int): Same as nn.Conv2d, while tuple is not supported.
+        groups (int): Same as nn.Conv2d.
+        bias (bool or str): If specified as `auto`, it will be decided by the
+            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+            False.
+    """
+
+    _version = 2
+
+    def __init__(self, *args, **kwargs):
+        super(ModulatedDeformConv2dPack, self).__init__(*args, **kwargs)
+        self.conv_offset = nn.Conv2d(
+            self.in_channels,
+            self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            padding=self.padding,
+            dilation=self.dilation,
+            bias=True)
+        self.init_weights()
+
+    def init_weights(self):
+        super(ModulatedDeformConv2dPack, self).init_weights()
+        if hasattr(self, 'conv_offset'):
+            self.conv_offset.weight.data.zero_()
+            self.conv_offset.bias.data.zero_()
+
+    def forward(self, x):
+        out = self.conv_offset(x)
+        o1, o2, mask = torch.chunk(out, 3, dim=1)
+        offset = torch.cat((o1, o2), dim=1)
+        mask = torch.sigmoid(mask)
+        return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
+                                       self.stride, self.padding,
+                                       self.dilation, self.groups,
+                                       self.deform_groups)
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        version = local_metadata.get('version', None)
+
+        if version is None or version < 2:
+            # the key is different in early versions
+            # In version < 2, ModulatedDeformConvPack
+            # loads previous benchmark models.
+            if (prefix + 'conv_offset.weight' not in state_dict
+                    and prefix[:-1] + '_offset.weight' in state_dict):
+                state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
+                    prefix[:-1] + '_offset.weight')
+            if (prefix + 'conv_offset.bias' not in state_dict
+                    and prefix[:-1] + '_offset.bias' in state_dict):
+                state_dict[prefix +
+                           'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
+                                                                '_offset.bias')
+
+        if version is not None and version > 1:
+            print_log(
+                f'ModulatedDeformConvPack {prefix.rstrip(".")} is upgraded to '
+                'version 2.',
+                logger='root')
+
+        super()._load_from_state_dict(state_dict, prefix, local_metadata,
+                                      strict, missing_keys, unexpected_keys,
+                                      error_msgs)
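A usage sketch for the DCNv2 wrapper above, assuming a CUDA build of the `_ext` extension; offsets and modulation masks are predicted internally by `conv_offset`, so the layer drops into a model like a regular conv.

import torch
from annotator.uniformer.mmcv.ops.modulated_deform_conv import \
    ModulatedDeformConv2dPack

dcn = ModulatedDeformConv2dPack(
    16, 32, kernel_size=3, padding=1, deform_groups=1).cuda()
x = torch.randn(2, 16, 28, 28, device='cuda')
y = dcn(x)                      # offsets/masks come from dcn.conv_offset(x)
print(y.shape)                  # torch.Size([2, 32, 28, 28])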
diff --git a/annotator/uniformer/mmcv/ops/multi_scale_deform_attn.py b/annotator/uniformer/mmcv/ops/multi_scale_deform_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..c52dda18b41705705b47dd0e995b124048c16fba
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/multi_scale_deform_attn.py
@@ -0,0 +1,358 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd.function import Function, once_differentiable
+
+from annotator.uniformer.mmcv import deprecated_api_warning
+from annotator.uniformer.mmcv.cnn import constant_init, xavier_init
+from annotator.uniformer.mmcv.cnn.bricks.registry import ATTENTION
+from annotator.uniformer.mmcv.runner import BaseModule
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+    '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward'])
+
+
+class MultiScaleDeformableAttnFunction(Function):
+
+    @staticmethod
+    def forward(ctx, value, value_spatial_shapes, value_level_start_index,
+                sampling_locations, attention_weights, im2col_step):
+        """GPU version of multi-scale deformable attention.
+
+        Args:
+            value (Tensor): The value has shape
+                (bs, num_keys, num_heads, embed_dims//num_heads)
+            value_spatial_shapes (Tensor): Spatial shape of
+                each feature map, has shape (num_levels, 2),
+                the last dimension 2 represents (h, w).
+            value_level_start_index (Tensor): The start index of each level
+                in the flattened value, has shape (num_levels, ).
+            sampling_locations (Tensor): The location of sampling points,
+                has shape
+                (bs, num_queries, num_heads, num_levels, num_points, 2),
+                the last dimension 2 represents (x, y).
+            attention_weights (Tensor): The weight of sampling points used
+                when calculating the attention, has shape
+                (bs, num_queries, num_heads, num_levels, num_points).
+            im2col_step (int): The step used in image to column.
+
+        Returns:
+            Tensor: has shape (bs, num_queries, embed_dims)
+        """
+
+        ctx.im2col_step = im2col_step
+        output = ext_module.ms_deform_attn_forward(
+            value,
+            value_spatial_shapes,
+            value_level_start_index,
+            sampling_locations,
+            attention_weights,
+            im2col_step=ctx.im2col_step)
+        ctx.save_for_backward(value, value_spatial_shapes,
+                              value_level_start_index, sampling_locations,
+                              attention_weights)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        """GPU version of backward function.
+
+        Args:
+            grad_output (Tensor): Gradient
+                of output tensor of forward.
+
+        Returns:
+             Tuple[Tensor]: Gradient
+                of input tensors in forward.
+        """
+        value, value_spatial_shapes, value_level_start_index,\
+            sampling_locations, attention_weights = ctx.saved_tensors
+        grad_value = torch.zeros_like(value)
+        grad_sampling_loc = torch.zeros_like(sampling_locations)
+        grad_attn_weight = torch.zeros_like(attention_weights)
+
+        ext_module.ms_deform_attn_backward(
+            value,
+            value_spatial_shapes,
+            value_level_start_index,
+            sampling_locations,
+            attention_weights,
+            grad_output.contiguous(),
+            grad_value,
+            grad_sampling_loc,
+            grad_attn_weight,
+            im2col_step=ctx.im2col_step)
+
+        return grad_value, None, None, \
+            grad_sampling_loc, grad_attn_weight, None
+
+
+def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes,
+                                        sampling_locations, attention_weights):
+    """CPU version of multi-scale deformable attention.
+
+    Args:
+        value (Tensor): The value has shape
+            (bs, num_keys, num_heads, embed_dims//num_heads).
+        value_spatial_shapes (Tensor): Spatial shape of
+            each feature map, has shape (num_levels, 2),
+            the last dimension 2 represents (h, w).
+        sampling_locations (Tensor): The location of sampling points,
+            has shape
+            (bs, num_queries, num_heads, num_levels, num_points, 2),
+            the last dimension 2 represents (x, y).
+        attention_weights (Tensor): The weight of sampling points used
+            when calculating the attention, has shape
+            (bs, num_queries, num_heads, num_levels, num_points).
+
+    Returns:
+        Tensor: has shape (bs, num_queries, embed_dims)
+    """
+
+    bs, _, num_heads, embed_dims = value.shape
+    _, num_queries, num_heads, num_levels, num_points, _ =\
+        sampling_locations.shape
+    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes],
+                             dim=1)
+    sampling_grids = 2 * sampling_locations - 1
+    sampling_value_list = []
+    for level, (H_, W_) in enumerate(value_spatial_shapes):
+        # bs, H_*W_, num_heads, embed_dims ->
+        # bs, H_*W_, num_heads*embed_dims ->
+        # bs, num_heads*embed_dims, H_*W_ ->
+        # bs*num_heads, embed_dims, H_, W_
+        value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(
+            bs * num_heads, embed_dims, H_, W_)
+        # bs, num_queries, num_heads, num_points, 2 ->
+        # bs, num_heads, num_queries, num_points, 2 ->
+        # bs*num_heads, num_queries, num_points, 2
+        sampling_grid_l_ = sampling_grids[:, :, :,
+                                          level].transpose(1, 2).flatten(0, 1)
+        # bs*num_heads, embed_dims, num_queries, num_points
+        sampling_value_l_ = F.grid_sample(
+            value_l_,
+            sampling_grid_l_,
+            mode='bilinear',
+            padding_mode='zeros',
+            align_corners=False)
+        sampling_value_list.append(sampling_value_l_)
+    # (bs, num_queries, num_heads, num_levels, num_points) ->
+    # (bs, num_heads, num_queries, num_levels, num_points) ->
+    # (bs, num_heads, 1, num_queries, num_levels*num_points)
+    attention_weights = attention_weights.transpose(1, 2).reshape(
+        bs * num_heads, 1, num_queries, num_levels * num_points)
+    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) *
+              attention_weights).sum(-1).view(bs, num_heads * embed_dims,
+                                              num_queries)
+    return output.transpose(1, 2).contiguous()
+
+
+@ATTENTION.register_module()
+class MultiScaleDeformableAttention(BaseModule):
+    """An attention module used in Deformable-Detr.
+
+    `Deformable DETR: Deformable Transformers for End-to-End Object Detection.
+    <https://arxiv.org/pdf/2010.04159.pdf>`_.
+
+    Args:
+        embed_dims (int): The embedding dimension of Attention.
+            Default: 256.
+        num_heads (int): Parallel attention heads. Default: 8.
+        num_levels (int): The number of feature map used in
+            Attention. Default: 4.
+        num_points (int): The number of sampling points for
+            each query in each head. Default: 4.
+        im2col_step (int): The step used in image_to_column.
+            Default: 64.
+        dropout (float): A Dropout layer on `inp_identity`.
+            Default: 0.1.
+        batch_first (bool): Whether Key, Query and Value have shape
+            (batch, n, embed_dim) or (n, batch, embed_dim).
+            Defaults to False.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: None.
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+
+    def __init__(self,
+                 embed_dims=256,
+                 num_heads=8,
+                 num_levels=4,
+                 num_points=4,
+                 im2col_step=64,
+                 dropout=0.1,
+                 batch_first=False,
+                 norm_cfg=None,
+                 init_cfg=None):
+        super().__init__(init_cfg)
+        if embed_dims % num_heads != 0:
+            raise ValueError(f'embed_dims must be divisible by num_heads, '
+                             f'but got {embed_dims} and {num_heads}')
+        dim_per_head = embed_dims // num_heads
+        self.norm_cfg = norm_cfg
+        self.dropout = nn.Dropout(dropout)
+        self.batch_first = batch_first
+
+        # you'd better set dim_per_head to a power of 2
+        # which is more efficient in the CUDA implementation
+        def _is_power_of_2(n):
+            if (not isinstance(n, int)) or (n < 0):
+                raise ValueError(
+                    'invalid input for _is_power_of_2: {} (type: {})'.format(
+                        n, type(n)))
+            return (n & (n - 1) == 0) and n != 0
+
+        if not _is_power_of_2(dim_per_head):
+            warnings.warn(
+                "You'd better set embed_dims in "
+                'MultiScaleDeformAttention to make '
+                'the dimension of each attention head a power of 2 '
+                'which is more efficient in our CUDA implementation.')
+
+        self.im2col_step = im2col_step
+        self.embed_dims = embed_dims
+        self.num_levels = num_levels
+        self.num_heads = num_heads
+        self.num_points = num_points
+        self.sampling_offsets = nn.Linear(
+            embed_dims, num_heads * num_levels * num_points * 2)
+        self.attention_weights = nn.Linear(embed_dims,
+                                           num_heads * num_levels * num_points)
+        self.value_proj = nn.Linear(embed_dims, embed_dims)
+        self.output_proj = nn.Linear(embed_dims, embed_dims)
+        self.init_weights()
+
+    def init_weights(self):
+        """Default initialization for Parameters of Module."""
+        constant_init(self.sampling_offsets, 0.)
+        thetas = torch.arange(
+            self.num_heads,
+            dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
+        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+        grid_init = (grid_init /
+                     grid_init.abs().max(-1, keepdim=True)[0]).view(
+                         self.num_heads, 1, 1,
+                         2).repeat(1, self.num_levels, self.num_points, 1)
+        for i in range(self.num_points):
+            grid_init[:, :, i, :] *= i + 1
+
+        self.sampling_offsets.bias.data = grid_init.view(-1)
+        constant_init(self.attention_weights, val=0., bias=0.)
+        xavier_init(self.value_proj, distribution='uniform', bias=0.)
+        xavier_init(self.output_proj, distribution='uniform', bias=0.)
+        self._is_init = True
+
+    @deprecated_api_warning({'residual': 'identity'},
+                            cls_name='MultiScaleDeformableAttention')
+    def forward(self,
+                query,
+                key=None,
+                value=None,
+                identity=None,
+                query_pos=None,
+                key_padding_mask=None,
+                reference_points=None,
+                spatial_shapes=None,
+                level_start_index=None,
+                **kwargs):
+        """Forward Function of MultiScaleDeformAttention.
+
+        Args:
+            query (Tensor): Query of Transformer with shape
+                (num_query, bs, embed_dims).
+            key (Tensor): The key tensor with shape
+                `(num_key, bs, embed_dims)`.
+            value (Tensor): The value tensor with shape
+                `(num_key, bs, embed_dims)`.
+            identity (Tensor): The tensor used for addition, with the
+                same shape as `query`. Default None. If None,
+                `query` will be used.
+            query_pos (Tensor): The positional encoding for `query`.
+                Default: None.
+            key_pos (Tensor): The positional encoding for `key`. Default
+                None.
+            reference_points (Tensor): The normalized reference
+                points with shape (bs, num_query, num_levels, 2),
+                all elements in the range [0, 1], top-left (0, 0),
+                bottom-right (1, 1), including the padding area;
+                or (bs, num_query, num_levels, 4), where the two
+                additional dimensions (w, h) form reference boxes.
+            key_padding_mask (Tensor): ByteTensor used to mask padded
+                elements in `value`, with shape [bs, num_key].
+            spatial_shapes (Tensor): Spatial shape of features in
+                different levels. With shape (num_levels, 2),
+                last dimension represents (h, w).
+            level_start_index (Tensor): The start index of each level.
+                A tensor has shape ``(num_levels, )`` and can be represented
+                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
+
+        Returns:
+             Tensor: forwarded results with shape [num_query, bs, embed_dims].
+        """
+
+        if value is None:
+            value = query
+
+        if identity is None:
+            identity = query
+        if query_pos is not None:
+            query = query + query_pos
+        if not self.batch_first:
+            # change to (bs, num_query ,embed_dims)
+            query = query.permute(1, 0, 2)
+            value = value.permute(1, 0, 2)
+
+        bs, num_query, _ = query.shape
+        bs, num_value, _ = value.shape
+        assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
+
+        value = self.value_proj(value)
+        if key_padding_mask is not None:
+            value = value.masked_fill(key_padding_mask[..., None], 0.0)
+        value = value.view(bs, num_value, self.num_heads, -1)
+        sampling_offsets = self.sampling_offsets(query).view(
+            bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
+        attention_weights = self.attention_weights(query).view(
+            bs, num_query, self.num_heads, self.num_levels * self.num_points)
+        attention_weights = attention_weights.softmax(-1)
+
+        attention_weights = attention_weights.view(bs, num_query,
+                                                   self.num_heads,
+                                                   self.num_levels,
+                                                   self.num_points)
+        if reference_points.shape[-1] == 2:
+            offset_normalizer = torch.stack(
+                [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
+            sampling_locations = reference_points[:, :, None, :, None, :] \
+                + sampling_offsets \
+                / offset_normalizer[None, None, None, :, None, :]
+        elif reference_points.shape[-1] == 4:
+            sampling_locations = reference_points[:, :, None, :, None, :2] \
+                + sampling_offsets / self.num_points \
+                * reference_points[:, :, None, :, None, 2:] \
+                * 0.5
+        else:
+            raise ValueError(
+                f'Last dim of reference_points must be'
+                f' 2 or 4, but get {reference_points.shape[-1]} instead.')
+        if torch.cuda.is_available() and value.is_cuda:
+            output = MultiScaleDeformableAttnFunction.apply(
+                value, spatial_shapes, level_start_index, sampling_locations,
+                attention_weights, self.im2col_step)
+        else:
+            output = multi_scale_deformable_attn_pytorch(
+                value, spatial_shapes, sampling_locations, attention_weights)
+
+        output = self.output_proj(output)
+
+        if not self.batch_first:
+            # (num_query, bs ,embed_dims)
+            output = output.permute(1, 0, 2)
+
+        return self.dropout(output) + identity
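A shape-level sketch exercising the pure-PyTorch fallback above with random tensors; all sizes are illustrative and follow the docstring conventions (sampling locations normalized to [0, 1]).

import torch
from annotator.uniformer.mmcv.ops.multi_scale_deform_attn import \
    multi_scale_deformable_attn_pytorch

bs, num_heads, head_dim = 2, 8, 32
num_queries, num_levels, num_points = 100, 2, 4
spatial_shapes = torch.tensor([[16, 16], [8, 8]])                     # two feature levels
num_keys = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())   # 16*16 + 8*8 = 320

value = torch.randn(bs, num_keys, num_heads, head_dim)
sampling_locations = torch.rand(
    bs, num_queries, num_heads, num_levels, num_points, 2)
attention_weights = torch.rand(
    bs, num_queries, num_heads, num_levels, num_points)
# normalize the weights over (levels x points), as the attention module does
attention_weights = attention_weights.flatten(-2).softmax(-1).view(
    bs, num_queries, num_heads, num_levels, num_points)

out = multi_scale_deformable_attn_pytorch(
    value, spatial_shapes, sampling_locations, attention_weights)
print(out.shape)  # torch.Size([2, 100, 256]) == (bs, num_queries, num_heads*head_dim)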
diff --git a/annotator/uniformer/mmcv/ops/nms.py b/annotator/uniformer/mmcv/ops/nms.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d9634281f486ab284091786886854c451368052
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/nms.py
@@ -0,0 +1,417 @@
+import os
+
+import numpy as np
+import torch
+
+from annotator.uniformer.mmcv.utils import deprecated_api_warning
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+    '_ext', ['nms', 'softnms', 'nms_match', 'nms_rotated'])
+
+
+# This function is modified from: https://github.com/pytorch/vision/
+class NMSop(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, bboxes, scores, iou_threshold, offset, score_threshold,
+                max_num):
+        is_filtering_by_score = score_threshold > 0
+        if is_filtering_by_score:
+            valid_mask = scores > score_threshold
+            bboxes, scores = bboxes[valid_mask], scores[valid_mask]
+            valid_inds = torch.nonzero(
+                valid_mask, as_tuple=False).squeeze(dim=1)
+
+        inds = ext_module.nms(
+            bboxes, scores, iou_threshold=float(iou_threshold), offset=offset)
+
+        if max_num > 0:
+            inds = inds[:max_num]
+        if is_filtering_by_score:
+            inds = valid_inds[inds]
+        return inds
+
+    @staticmethod
+    def symbolic(g, bboxes, scores, iou_threshold, offset, score_threshold,
+                 max_num):
+        from ..onnx import is_custom_op_loaded
+        has_custom_op = is_custom_op_loaded()
+        # TensorRT nms plugin is aligned with original nms in ONNXRuntime
+        is_trt_backend = os.environ.get('ONNX_BACKEND') == 'MMCVTensorRT'
+        if has_custom_op and (not is_trt_backend):
+            return g.op(
+                'mmcv::NonMaxSuppression',
+                bboxes,
+                scores,
+                iou_threshold_f=float(iou_threshold),
+                offset_i=int(offset))
+        else:
+            from torch.onnx.symbolic_opset9 import select, squeeze, unsqueeze
+            from ..onnx.onnx_utils.symbolic_helper import _size_helper
+
+            boxes = unsqueeze(g, bboxes, 0)
+            scores = unsqueeze(g, unsqueeze(g, scores, 0), 0)
+
+            if max_num > 0:
+                max_num = g.op(
+                    'Constant',
+                    value_t=torch.tensor(max_num, dtype=torch.long))
+            else:
+                dim = g.op('Constant', value_t=torch.tensor(0))
+                max_num = _size_helper(g, bboxes, dim)
+            max_output_per_class = max_num
+            iou_threshold = g.op(
+                'Constant',
+                value_t=torch.tensor([iou_threshold], dtype=torch.float))
+            score_threshold = g.op(
+                'Constant',
+                value_t=torch.tensor([score_threshold], dtype=torch.float))
+            nms_out = g.op('NonMaxSuppression', boxes, scores,
+                           max_output_per_class, iou_threshold,
+                           score_threshold)
+            return squeeze(
+                g,
+                select(
+                    g, nms_out, 1,
+                    g.op(
+                        'Constant',
+                        value_t=torch.tensor([2], dtype=torch.long))), 1)
+
+
+class SoftNMSop(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, boxes, scores, iou_threshold, sigma, min_score, method,
+                offset):
+        dets = boxes.new_empty((boxes.size(0), 5), device='cpu')
+        inds = ext_module.softnms(
+            boxes.cpu(),
+            scores.cpu(),
+            dets.cpu(),
+            iou_threshold=float(iou_threshold),
+            sigma=float(sigma),
+            min_score=float(min_score),
+            method=int(method),
+            offset=int(offset))
+        return dets, inds
+
+    @staticmethod
+    def symbolic(g, boxes, scores, iou_threshold, sigma, min_score, method,
+                 offset):
+        from packaging import version
+        assert version.parse(torch.__version__) >= version.parse('1.7.0')
+        nms_out = g.op(
+            'mmcv::SoftNonMaxSuppression',
+            boxes,
+            scores,
+            iou_threshold_f=float(iou_threshold),
+            sigma_f=float(sigma),
+            min_score_f=float(min_score),
+            method_i=int(method),
+            offset_i=int(offset),
+            outputs=2)
+        return nms_out
+
+
+@deprecated_api_warning({'iou_thr': 'iou_threshold'})
+def nms(boxes, scores, iou_threshold, offset=0, score_threshold=0, max_num=-1):
+    """Dispatch to either CPU or GPU NMS implementations.
+
+    The input can be either a torch tensor or a numpy array. GPU NMS will be
+    used if the input is a GPU tensor; otherwise CPU NMS will be used.
+    The returned type will always be the same as the inputs.
+
+    Arguments:
+        boxes (torch.Tensor or np.ndarray): boxes in shape (N, 4).
+        scores (torch.Tensor or np.ndarray): scores in shape (N, ).
+        iou_threshold (float): IoU threshold for NMS.
+        offset (int, 0 or 1): boxes' width or height is (x2 - x1 + offset).
+        score_threshold (float): score threshold for NMS.
+        max_num (int): maximum number of boxes after NMS.
+
+    Returns:
+        tuple: kept dets (boxes and scores) and indices, which are always of \
+            the same data type as the inputs.
+
+    Example:
+        >>> boxes = np.array([[49.1, 32.4, 51.0, 35.9],
+        >>>                   [49.3, 32.9, 51.0, 35.3],
+        >>>                   [49.2, 31.8, 51.0, 35.4],
+        >>>                   [35.1, 11.5, 39.1, 15.7],
+        >>>                   [35.6, 11.8, 39.3, 14.2],
+        >>>                   [35.3, 11.5, 39.9, 14.5],
+        >>>                   [35.2, 11.7, 39.7, 15.7]], dtype=np.float32)
+        >>> scores = np.array([0.9, 0.9, 0.5, 0.5, 0.5, 0.4, 0.3],\
+               dtype=np.float32)
+        >>> iou_threshold = 0.6
+        >>> dets, inds = nms(boxes, scores, iou_threshold)
+        >>> assert len(inds) == len(dets) == 3
+    """
+    assert isinstance(boxes, (torch.Tensor, np.ndarray))
+    assert isinstance(scores, (torch.Tensor, np.ndarray))
+    is_numpy = False
+    if isinstance(boxes, np.ndarray):
+        is_numpy = True
+        boxes = torch.from_numpy(boxes)
+    if isinstance(scores, np.ndarray):
+        scores = torch.from_numpy(scores)
+    assert boxes.size(1) == 4
+    assert boxes.size(0) == scores.size(0)
+    assert offset in (0, 1)
+
+    if torch.__version__ == 'parrots':
+        indata_list = [boxes, scores]
+        indata_dict = {
+            'iou_threshold': float(iou_threshold),
+            'offset': int(offset)
+        }
+        inds = ext_module.nms(*indata_list, **indata_dict)
+    else:
+        inds = NMSop.apply(boxes, scores, iou_threshold, offset,
+                           score_threshold, max_num)
+    dets = torch.cat((boxes[inds], scores[inds].reshape(-1, 1)), dim=1)
+    if is_numpy:
+        dets = dets.cpu().numpy()
+        inds = inds.cpu().numpy()
+    return dets, inds
+
+
+@deprecated_api_warning({'iou_thr': 'iou_threshold'})
+def soft_nms(boxes,
+             scores,
+             iou_threshold=0.3,
+             sigma=0.5,
+             min_score=1e-3,
+             method='linear',
+             offset=0):
+    """Dispatch to only CPU Soft NMS implementations.
+
+    The input can be either a torch tensor or numpy array.
+    The returned type will always be the same as inputs.
+
+    Arguments:
+        boxes (torch.Tensor or np.ndarray): boxes in shape (N, 4).
+        scores (torch.Tensor or np.ndarray): scores in shape (N, ).
+        iou_threshold (float): IoU threshold for NMS.
+        sigma (float): hyperparameter for the gaussian method.
+        min_score (float): score filter threshold.
+        method (str): one of 'naive', 'linear' or 'gaussian'.
+        offset (int, 0 or 1): boxes' width or height is (x2 - x1 + offset).
+
+    Returns:
+        tuple: kept dets (boxes and scores) and indices, which are always of \
+            the same data type as the inputs.
+
+    Example:
+        >>> boxes = np.array([[4., 3., 5., 3.],
+        >>>                   [4., 3., 5., 4.],
+        >>>                   [3., 1., 3., 1.],
+        >>>                   [3., 1., 3., 1.],
+        >>>                   [3., 1., 3., 1.],
+        >>>                   [3., 1., 3., 1.]], dtype=np.float32)
+        >>> scores = np.array([0.9, 0.9, 0.5, 0.5, 0.4, 0.0], dtype=np.float32)
+        >>> iou_threshold = 0.6
+        >>> dets, inds = soft_nms(boxes, scores, iou_threshold, sigma=0.5)
+        >>> assert len(inds) == len(dets) == 5
+    """
+
+    assert isinstance(boxes, (torch.Tensor, np.ndarray))
+    assert isinstance(scores, (torch.Tensor, np.ndarray))
+    is_numpy = False
+    if isinstance(boxes, np.ndarray):
+        is_numpy = True
+        boxes = torch.from_numpy(boxes)
+    if isinstance(scores, np.ndarray):
+        scores = torch.from_numpy(scores)
+    assert boxes.size(1) == 4
+    assert boxes.size(0) == scores.size(0)
+    assert offset in (0, 1)
+    method_dict = {'naive': 0, 'linear': 1, 'gaussian': 2}
+    assert method in method_dict.keys()
+
+    if torch.__version__ == 'parrots':
+        dets = boxes.new_empty((boxes.size(0), 5), device='cpu')
+        indata_list = [boxes.cpu(), scores.cpu(), dets.cpu()]
+        indata_dict = {
+            'iou_threshold': float(iou_threshold),
+            'sigma': float(sigma),
+            'min_score': min_score,
+            'method': method_dict[method],
+            'offset': int(offset)
+        }
+        inds = ext_module.softnms(*indata_list, **indata_dict)
+    else:
+        dets, inds = SoftNMSop.apply(boxes.cpu(), scores.cpu(),
+                                     float(iou_threshold), float(sigma),
+                                     float(min_score), method_dict[method],
+                                     int(offset))
+
+    dets = dets[:inds.size(0)]
+
+    if is_numpy:
+        dets = dets.cpu().numpy()
+        inds = inds.cpu().numpy()
+        return dets, inds
+    else:
+        return dets.to(device=boxes.device), inds.to(device=boxes.device)
+
+
+def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
+    """Performs non-maximum suppression in a batched fashion.
+
+    Modified from https://github.com/pytorch/vision/blob
+    /505cd6957711af790211896d32b40291bea1bc21/torchvision/ops/boxes.py#L39.
+    In order to perform NMS independently per class, we add an offset to all
+    the boxes. The offset is dependent only on the class idx, and is large
+    enough so that boxes from different classes do not overlap.
+
+    Arguments:
+        boxes (torch.Tensor): boxes in shape (N, 4).
+        scores (torch.Tensor): scores in shape (N, ).
+        idxs (torch.Tensor): each index value corresponds to a bbox cluster,
+            and NMS will not be applied between elements of different idxs,
+            shape (N, ).
+        nms_cfg (dict): specify nms type and other parameters like iou_thr.
+            Possible keys include the following.
+
+            - iou_thr (float): IoU threshold used for NMS.
+            - split_thr (float): threshold number of boxes. In some cases the
+                number of boxes is large (e.g., 200k). To avoid OOM during
+                training, users can set `split_thr` to a small value.
+                If the number of boxes is greater than the threshold, it will
+                perform NMS on each group of boxes separately and sequentially.
+                Defaults to 10000.
+        class_agnostic (bool): if true, nms is class agnostic,
+            i.e. IoU thresholding happens over all boxes,
+            regardless of the predicted class.
+
+    Returns:
+        tuple: kept dets and indices.
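+
+    Example:
+        >>> # Illustrative sketch only; it assumes the compiled ``_ext``
+        >>> # extension providing the NMS op is available.
+        >>> boxes = torch.rand(100, 4) * 50
+        >>> boxes[:, 2:] += boxes[:, :2]
+        >>> scores = torch.rand(100)
+        >>> idxs = torch.randint(0, 3, (100,))
+        >>> nms_cfg = dict(type='nms', iou_threshold=0.5)
+        >>> dets, keep = batched_nms(boxes, scores, idxs, nms_cfg)
+        >>> assert dets.size(1) == 5 and keep.dim() == 1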
+    """
+    nms_cfg_ = nms_cfg.copy()
+    class_agnostic = nms_cfg_.pop('class_agnostic', class_agnostic)
+    if class_agnostic:
+        boxes_for_nms = boxes
+    else:
+        max_coordinate = boxes.max()
+        offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
+        boxes_for_nms = boxes + offsets[:, None]
+
+    nms_type = nms_cfg_.pop('type', 'nms')
+    nms_op = eval(nms_type)
+
+    split_thr = nms_cfg_.pop('split_thr', 10000)
+    # Won't split to multiple nms nodes when exporting to onnx
+    if boxes_for_nms.shape[0] < split_thr or torch.onnx.is_in_onnx_export():
+        dets, keep = nms_op(boxes_for_nms, scores, **nms_cfg_)
+        boxes = boxes[keep]
+        # -1 indexing works abnormally in TensorRT, so index column 4 directly.
+        # This assumes `dets` has shape (N, 5), where the last column
+        # is the score.
+        # TODO: more elegant way to handle the dimension issue.
+        # Some types of nms reweight the scores, such as SoftNMS.
+        scores = dets[:, 4]
+    else:
+        max_num = nms_cfg_.pop('max_num', -1)
+        total_mask = scores.new_zeros(scores.size(), dtype=torch.bool)
+        # Some types of nms reweight the scores, such as SoftNMS.
+        scores_after_nms = scores.new_zeros(scores.size())
+        for id in torch.unique(idxs):
+            mask = (idxs == id).nonzero(as_tuple=False).view(-1)
+            dets, keep = nms_op(boxes_for_nms[mask], scores[mask], **nms_cfg_)
+            total_mask[mask[keep]] = True
+            scores_after_nms[mask[keep]] = dets[:, -1]
+        keep = total_mask.nonzero(as_tuple=False).view(-1)
+
+        scores, inds = scores_after_nms[keep].sort(descending=True)
+        keep = keep[inds]
+        boxes = boxes[keep]
+
+        if max_num > 0:
+            keep = keep[:max_num]
+            boxes = boxes[:max_num]
+            scores = scores[:max_num]
+
+    return torch.cat([boxes, scores[:, None]], -1), keep
+
+
+def nms_match(dets, iou_threshold):
+    """Matched dets into different groups by NMS.
+
+    NMS match is Similar to NMS but when a bbox is suppressed, nms match will
+    record the indice of suppressed bbox and form a group with the indice of
+    kept bbox. In each group, indice is sorted as score order.
+
+    Arguments:
+        dets (torch.Tensor | np.ndarray): Det boxes with scores, shape (N, 5).
+        iou_threshold (float): IoU threshold for NMS.
+
+    Returns:
+        List[torch.Tensor | np.ndarray]: The outer list corresponds to the
+            different matched groups; each inner Tensor holds the indices of
+            one group in score order.
+    """
+    if dets.shape[0] == 0:
+        matched = []
+    else:
+        assert dets.shape[-1] == 5, 'input dets.shape should be (N, 5), ' \
+                                    f'but got {dets.shape}'
+        if isinstance(dets, torch.Tensor):
+            dets_t = dets.detach().cpu()
+        else:
+            dets_t = torch.from_numpy(dets)
+        indata_list = [dets_t]
+        indata_dict = {'iou_threshold': float(iou_threshold)}
+        matched = ext_module.nms_match(*indata_list, **indata_dict)
+        if torch.__version__ == 'parrots':
+            matched = matched.tolist()
+
+    if isinstance(dets, torch.Tensor):
+        return [dets.new_tensor(m, dtype=torch.long) for m in matched]
+    else:
+        # np.int is removed in recent NumPy; use a fixed-width integer dtype.
+        return [np.array(m, dtype=np.int64) for m in matched]
+
+
+def nms_rotated(dets, scores, iou_threshold, labels=None):
+    """Performs non-maximum suppression (NMS) on the rotated boxes according to
+    their intersection-over-union (IoU).
+
+    Rotated NMS iteratively removes lower scoring rotated boxes which have an
+    IoU greater than iou_threshold with another (higher scoring) rotated box.
+
+    Args:
+        dets (Tensor): Rotated boxes in shape (N, 5). They are expected to \
+            be in (x_ctr, y_ctr, width, height, angle_radian) format.
+        scores (Tensor): scores in shape (N, ).
+        iou_threshold (float): IoU thresh for NMS.
+        labels (Tensor): boxes' label in shape (N,).
+
+    Returns:
+        tuple: kept dets (boxes and scores) and indices, which are always of \
+            the same data type as the input.
+    """
+    if dets.shape[0] == 0:
+        return dets, None
+    multi_label = labels is not None
+    if multi_label:
+        dets_wl = torch.cat((dets, labels.unsqueeze(1)), 1)
+    else:
+        dets_wl = dets
+    _, order = scores.sort(0, descending=True)
+    dets_sorted = dets_wl.index_select(0, order)
+
+    if torch.__version__ == 'parrots':
+        keep_inds = ext_module.nms_rotated(
+            dets_wl,
+            scores,
+            order,
+            dets_sorted,
+            iou_threshold=iou_threshold,
+            multi_label=multi_label)
+    else:
+        keep_inds = ext_module.nms_rotated(dets_wl, scores, order, dets_sorted,
+                                           iou_threshold, multi_label)
+    dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)),
+                     dim=1)
+    return dets, keep_inds
diff --git a/annotator/uniformer/mmcv/ops/pixel_group.py b/annotator/uniformer/mmcv/ops/pixel_group.py
new file mode 100644
index 0000000000000000000000000000000000000000..2143c75f835a467c802fc3c37ecd3ac0f85bcda4
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/pixel_group.py
@@ -0,0 +1,75 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['pixel_group'])
+
+
+def pixel_group(score, mask, embedding, kernel_label, kernel_contour,
+                kernel_region_num, distance_threshold):
+    """Group pixels into text instances, which is widely used text detection
+    methods.
+
+    Arguments:
+        score (np.array or Tensor): The foreground score with size hxw.
+        mask (np.array or Tensor): The foreground mask with size hxw.
+        embedding (np.array or Tensor): The embedding with size hxwxc to
+            distinguish instances.
+        kernel_label (np.array or Tensor): The instance kernel index with
+            size hxw.
+        kernel_contour (np.array or Tensor): The kernel contour with size hxw.
+        kernel_region_num (int): The instance kernel region number.
+        distance_threshold (float): The embedding distance threshold between
+            kernel and pixel in one instance.
+
+    Returns:
+        pixel_assignment (List[List[float]]): The instance coordinate list.
+            Each element consists of averaged confidence, pixel number, and
+            coordinates (x_i, y_i for all pixels) in order.
+    """
+    assert isinstance(score, (torch.Tensor, np.ndarray))
+    assert isinstance(mask, (torch.Tensor, np.ndarray))
+    assert isinstance(embedding, (torch.Tensor, np.ndarray))
+    assert isinstance(kernel_label, (torch.Tensor, np.ndarray))
+    assert isinstance(kernel_contour, (torch.Tensor, np.ndarray))
+    assert isinstance(kernel_region_num, int)
+    assert isinstance(distance_threshold, float)
+
+    if isinstance(score, np.ndarray):
+        score = torch.from_numpy(score)
+    if isinstance(mask, np.ndarray):
+        mask = torch.from_numpy(mask)
+    if isinstance(embedding, np.ndarray):
+        embedding = torch.from_numpy(embedding)
+    if isinstance(kernel_label, np.ndarray):
+        kernel_label = torch.from_numpy(kernel_label)
+    if isinstance(kernel_contour, np.ndarray):
+        kernel_contour = torch.from_numpy(kernel_contour)
+
+    if torch.__version__ == 'parrots':
+        label = ext_module.pixel_group(
+            score,
+            mask,
+            embedding,
+            kernel_label,
+            kernel_contour,
+            kernel_region_num=kernel_region_num,
+            distance_threshold=distance_threshold)
+        label = label.tolist()
+        label = label[0]
+        list_index = kernel_region_num
+        pixel_assignment = []
+        for x in range(kernel_region_num):
+            pixel_assignment.append(
+                np.array(
+                    label[list_index:list_index + int(label[x])],
+                    dtype=np.float64))  # np.float is removed in recent NumPy
+            list_index = list_index + int(label[x])
+    else:
+        pixel_assignment = ext_module.pixel_group(score, mask, embedding,
+                                                  kernel_label, kernel_contour,
+                                                  kernel_region_num,
+                                                  distance_threshold)
+    return pixel_assignment
diff --git a/annotator/uniformer/mmcv/ops/point_sample.py b/annotator/uniformer/mmcv/ops/point_sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..267f4b3c56630acd85f9bdc630b7be09abab0aba
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/point_sample.py
@@ -0,0 +1,336 @@
+# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend  # noqa
+
+from os import path as osp
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.modules.utils import _pair
+from torch.onnx.operators import shape_as_tensor
+
+
+def bilinear_grid_sample(im, grid, align_corners=False):
+    """Given an input and a flow-field grid, computes the output using input
+    values and pixel locations from grid. Supported only bilinear interpolation
+    method to sample the input pixels.
+
+    Args:
+        im (torch.Tensor): Input feature map, shape (N, C, H, W)
+        grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
+        align_corners (bool): If set to True, the extrema (-1 and 1) are
+            considered as referring to the center points of the input's
+            corner pixels. If set to False, they are instead considered as
+            referring to the corner points of the input's corner pixels,
+            making the sampling more resolution agnostic.
+    Returns:
+        torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
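+
+    Example:
+        >>> # Illustrative sketch; for in-range grids the result is expected
+        >>> # to match the native grid_sample with zero padding.
+        >>> im = torch.randn(1, 3, 8, 8)
+        >>> grid = torch.rand(1, 4, 4, 2) * 2 - 1
+        >>> out = bilinear_grid_sample(im, grid, align_corners=False)
+        >>> ref = F.grid_sample(im, grid, align_corners=False)
+        >>> assert torch.allclose(out, ref, atol=1e-5)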
+    """
+    n, c, h, w = im.shape
+    gn, gh, gw, _ = grid.shape
+    assert n == gn
+
+    x = grid[:, :, :, 0]
+    y = grid[:, :, :, 1]
+
+    if align_corners:
+        x = ((x + 1) / 2) * (w - 1)
+        y = ((y + 1) / 2) * (h - 1)
+    else:
+        x = ((x + 1) * w - 1) / 2
+        y = ((y + 1) * h - 1) / 2
+
+    x = x.view(n, -1)
+    y = y.view(n, -1)
+
+    x0 = torch.floor(x).long()
+    y0 = torch.floor(y).long()
+    x1 = x0 + 1
+    y1 = y0 + 1
+
+    wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
+    wb = ((x1 - x) * (y - y0)).unsqueeze(1)
+    wc = ((x - x0) * (y1 - y)).unsqueeze(1)
+    wd = ((x - x0) * (y - y0)).unsqueeze(1)
+
+    # Pad the input to emulate the default zero padding of grid_sample
+    im_padded = F.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
+    padded_h = h + 2
+    padded_w = w + 2
+    # shift point positions to account for the padding
+    x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1
+
+    # Clip coordinates to padded image size
+    x0 = torch.where(x0 < 0, torch.tensor(0), x0)
+    x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1), x0)
+    x1 = torch.where(x1 < 0, torch.tensor(0), x1)
+    x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1), x1)
+    y0 = torch.where(y0 < 0, torch.tensor(0), y0)
+    y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1), y0)
+    y1 = torch.where(y1 < 0, torch.tensor(0), y1)
+    y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1), y1)
+
+    im_padded = im_padded.view(n, c, -1)
+
+    x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
+    x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
+    x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
+    x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
+
+    Ia = torch.gather(im_padded, 2, x0_y0)
+    Ib = torch.gather(im_padded, 2, x0_y1)
+    Ic = torch.gather(im_padded, 2, x1_y0)
+    Id = torch.gather(im_padded, 2, x1_y1)
+
+    return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)
+
+
+def is_in_onnx_export_without_custom_ops():
+    from annotator.uniformer.mmcv.ops import get_onnxruntime_op_path
+    ort_custom_op_path = get_onnxruntime_op_path()
+    return torch.onnx.is_in_onnx_export(
+    ) and not osp.exists(ort_custom_op_path)
+
+
+def normalize(grid):
+    """Normalize input grid from [-1, 1] to [0, 1]
+    Args:
+        grid (Tensor): The grid to be normalize, range [-1, 1].
+    Returns:
+        Tensor: Normalized grid, range [0, 1].
+    """
+
+    return (grid + 1.0) / 2.0
+
+
+def denormalize(grid):
+    """Denormalize input grid from range [0, 1] to [-1, 1]
+    Args:
+        grid (Tensor): The grid to be denormalize, range [0, 1].
+    Returns:
+        Tensor: Denormalized grid, range [-1, 1].
+    """
+
+    return grid * 2.0 - 1.0
+
+
+def generate_grid(num_grid, size, device):
+    """Generate regular square grid of points in [0, 1] x [0, 1] coordinate
+    space.
+
+    Args:
+        num_grid (int): The number of grids to sample, one for each region.
+        size (tuple(int, int)): The side size of the regular grid.
+        device (torch.device): Desired device of returned tensor.
+
+    Returns:
+        (torch.Tensor): A tensor of shape (num_grid, size[0]*size[1], 2) that
+            contains coordinates for the regular grids.
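+
+    Example:
+        >>> # Illustrative sketch; the returned coordinates lie in [0, 1].
+        >>> grid = generate_grid(2, (3, 3), device=torch.device('cpu'))
+        >>> grid.shape
+        torch.Size([2, 9, 2])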
+    """
+
+    affine_trans = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]], device=device)
+    grid = F.affine_grid(
+        affine_trans, torch.Size((1, 1, *size)), align_corners=False)
+    grid = normalize(grid)
+    return grid.view(1, -1, 2).expand(num_grid, -1, -1)
+
+
+def rel_roi_point_to_abs_img_point(rois, rel_roi_points):
+    """Convert roi based relative point coordinates to image based absolute
+    point coordinates.
+
+    Args:
+        rois (Tensor): RoIs or BBoxes, shape (N, 4) or (N, 5)
+        rel_roi_points (Tensor): Point coordinates inside RoI, relative to
+            the RoI location, range (0, 1), shape (N, P, 2)
+    Returns:
+        Tensor: Image based absolute point coordinates, shape (N, P, 2)
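+
+    Example:
+        >>> # Illustrative sketch; a point at the RoI center (0.5, 0.5) maps
+        >>> # to the absolute image coordinate of that center.
+        >>> rois = torch.tensor([[10., 10., 30., 50.]])
+        >>> rel_roi_points = torch.tensor([[[0.5, 0.5]]])
+        >>> rel_roi_point_to_abs_img_point(rois, rel_roi_points)
+        tensor([[[20., 30.]]])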
+    """
+
+    with torch.no_grad():
+        assert rel_roi_points.size(0) == rois.size(0)
+        assert rois.dim() == 2
+        assert rel_roi_points.dim() == 3
+        assert rel_roi_points.size(2) == 2
+        # remove batch idx
+        if rois.size(1) == 5:
+            rois = rois[:, 1:]
+        abs_img_points = rel_roi_points.clone()
+        # To avoid an error when exporting to onnx, use independent
+        # variables instead of inplace computation
+        xs = abs_img_points[:, :, 0] * (rois[:, None, 2] - rois[:, None, 0])
+        ys = abs_img_points[:, :, 1] * (rois[:, None, 3] - rois[:, None, 1])
+        xs += rois[:, None, 0]
+        ys += rois[:, None, 1]
+        abs_img_points = torch.stack([xs, ys], dim=2)
+    return abs_img_points
+
+
+def get_shape_from_feature_map(x):
+    """Get spatial resolution of input feature map considering exporting to
+    onnx mode.
+
+    Args:
+        x (torch.Tensor): Input tensor, shape (N, C, H, W)
+    Returns:
+        torch.Tensor: Spatial resolution (width, height), shape (1, 1, 2)
+    """
+    if torch.onnx.is_in_onnx_export():
+        img_shape = shape_as_tensor(x)[2:].flip(0).view(1, 1, 2).to(
+            x.device).float()
+    else:
+        img_shape = torch.tensor(x.shape[2:]).flip(0).view(1, 1, 2).to(
+            x.device).float()
+    return img_shape
+
+
+def abs_img_point_to_rel_img_point(abs_img_points, img, spatial_scale=1.):
+    """Convert image based absolute point coordinates to image based relative
+    coordinates for sampling.
+
+    Args:
+        abs_img_points (Tensor): Image based absolute point coordinates,
+            shape (N, P, 2)
+        img (tuple/Tensor): (height, width) of image or feature map.
+        spatial_scale (float): Scale points by this factor. Default: 1.
+
+    Returns:
+        Tensor: Image based relative point coordinates for sampling,
+            shape (N, P, 2)
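+
+    Example:
+        >>> # Illustrative sketch; absolute coordinates are divided by the
+        >>> # (width, height) of the image, here (64, 32).
+        >>> abs_img_points = torch.tensor([[[16., 8.]]])
+        >>> abs_img_point_to_rel_img_point(abs_img_points, (32, 64))
+        tensor([[[0.2500, 0.2500]]])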
+    """
+
+    assert (isinstance(img, tuple) and len(img) == 2) or \
+           (isinstance(img, torch.Tensor) and len(img.shape) == 4)
+
+    if isinstance(img, tuple):
+        h, w = img
+        scale = torch.tensor([w, h],
+                             dtype=torch.float,
+                             device=abs_img_points.device)
+        scale = scale.view(1, 1, 2)
+    else:
+        scale = get_shape_from_feature_map(img)
+
+    return abs_img_points / scale * spatial_scale
+
+
+def rel_roi_point_to_rel_img_point(rois,
+                                   rel_roi_points,
+                                   img,
+                                   spatial_scale=1.):
+    """Convert roi based relative point coordinates to image based absolute
+    point coordinates.
+
+    Args:
+        rois (Tensor): RoIs or BBoxes, shape (N, 4) or (N, 5)
+        rel_roi_points (Tensor): Point coordinates inside RoI, relative to
+            the RoI location, range (0, 1), shape (N, P, 2)
+        img (tuple/Tensor): (height, width) of image or feature map.
+        spatial_scale (float): Scale points by this factor. Default: 1.
+
+    Returns:
+        Tensor: Image based relative point coordinates for sampling,
+            shape (N, P, 2)
+    """
+
+    abs_img_point = rel_roi_point_to_abs_img_point(rois, rel_roi_points)
+    rel_img_point = abs_img_point_to_rel_img_point(abs_img_point, img,
+                                                   spatial_scale)
+
+    return rel_img_point
+
+
+def point_sample(input, points, align_corners=False, **kwargs):
+    """A wrapper around :func:`grid_sample` to support 3D point_coords tensors
+    Unlike :func:`torch.nn.functional.grid_sample` it assumes point_coords to
+    lie inside ``[0, 1] x [0, 1]`` square.
+
+    Args:
+        input (Tensor): Feature map, shape (N, C, H, W).
+        points (Tensor): Image based absolute point coordinates (normalized),
+            range [0, 1] x [0, 1], shape (N, P, 2) or (N, Hgrid, Wgrid, 2).
+        align_corners (bool): Whether to use ``align_corners`` in the
+            underlying grid_sample call. Default: False.
+
+    Returns:
+        Tensor: Features of `point` on `input`, shape (N, C, P) or
+            (N, C, Hgrid, Wgrid).
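+
+    Example:
+        >>> # Illustrative sketch; a 3D point tensor yields (N, C, P) features.
+        >>> feat = torch.randn(1, 16, 32, 32)
+        >>> points = torch.rand(1, 50, 2)
+        >>> point_sample(feat, points).shape
+        torch.Size([1, 16, 50])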
+    """
+
+    add_dim = False
+    if points.dim() == 3:
+        add_dim = True
+        points = points.unsqueeze(2)
+    if is_in_onnx_export_without_custom_ops():
+        # If custom ops for onnx runtime not compiled use python
+        # implementation of grid_sample function to make onnx graph
+        # with supported nodes
+        output = bilinear_grid_sample(
+            input, denormalize(points), align_corners=align_corners)
+    else:
+        output = F.grid_sample(
+            input, denormalize(points), align_corners=align_corners, **kwargs)
+    if add_dim:
+        output = output.squeeze(3)
+    return output
+
+
+class SimpleRoIAlign(nn.Module):
+
+    def __init__(self, output_size, spatial_scale, aligned=True):
+        """Simple RoI align in PointRend, faster than standard RoIAlign.
+
+        Args:
+            output_size (tuple[int]): h, w
+            spatial_scale (float): scale the input boxes by this number
+            aligned (bool): if False, use the legacy implementation in
+                MMDetection, where align_corners=True is used in
+                F.grid_sample. If True, align the results more perfectly.
+        """
+
+        super(SimpleRoIAlign, self).__init__()
+        self.output_size = _pair(output_size)
+        self.spatial_scale = float(spatial_scale)
+        # to be consistent with other RoI ops
+        self.use_torchvision = False
+        self.aligned = aligned
+
+    def forward(self, features, rois):
+        num_imgs = features.size(0)
+        num_rois = rois.size(0)
+        rel_roi_points = generate_grid(
+            num_rois, self.output_size, device=rois.device)
+
+        if torch.onnx.is_in_onnx_export():
+            rel_img_points = rel_roi_point_to_rel_img_point(
+                rois, rel_roi_points, features, self.spatial_scale)
+            rel_img_points = rel_img_points.reshape(num_imgs, -1,
+                                                    *rel_img_points.shape[1:])
+            point_feats = point_sample(
+                features, rel_img_points, align_corners=not self.aligned)
+            point_feats = point_feats.transpose(1, 2)
+        else:
+            point_feats = []
+            for batch_ind in range(num_imgs):
+                # unravel batch dim
+                feat = features[batch_ind].unsqueeze(0)
+                inds = (rois[:, 0].long() == batch_ind)
+                if inds.any():
+                    rel_img_points = rel_roi_point_to_rel_img_point(
+                        rois[inds], rel_roi_points[inds], feat,
+                        self.spatial_scale).unsqueeze(0)
+                    point_feat = point_sample(
+                        feat, rel_img_points, align_corners=not self.aligned)
+                    point_feat = point_feat.squeeze(0).transpose(0, 1)
+                    point_feats.append(point_feat)
+
+            point_feats = torch.cat(point_feats, dim=0)
+
+        channels = features.size(1)
+        roi_feats = point_feats.reshape(num_rois, channels, *self.output_size)
+
+        return roi_feats
+
+    def __repr__(self):
+        format_str = self.__class__.__name__
+        format_str += '(output_size={}, spatial_scale={})'.format(
+            self.output_size, self.spatial_scale)
+        return format_str
diff --git a/annotator/uniformer/mmcv/ops/points_in_boxes.py b/annotator/uniformer/mmcv/ops/points_in_boxes.py
new file mode 100644
index 0000000000000000000000000000000000000000..4003173a53052161dbcd687a2fa1d755642fdab8
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/points_in_boxes.py
@@ -0,0 +1,133 @@
+import torch
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+    'points_in_boxes_part_forward', 'points_in_boxes_cpu_forward',
+    'points_in_boxes_all_forward'
+])
+
+
+def points_in_boxes_part(points, boxes):
+    """Find the box in which each point is (CUDA).
+
+    Args:
+        points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate
+        boxes (torch.Tensor): [B, T, 7],
+            num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz] in
+            LiDAR/DEPTH coordinate, (x, y, z) is the bottom center
+
+    Returns:
+        box_idxs_of_pts (torch.Tensor): (B, M), default background = -1
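+
+    Example:
+        >>> # Illustrative sketch; it assumes a CUDA build of the ``_ext``
+        >>> # extension and a visible GPU.
+        >>> points = torch.rand(1, 128, 3).cuda() * 10
+        >>> boxes = torch.tensor([[[5., 5., 0., 2., 2., 4., 0.]]]).cuda()
+        >>> box_idxs = points_in_boxes_part(points, boxes)  # shape (1, 128)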
+    """
+    assert points.shape[0] == boxes.shape[0], \
+        'Points and boxes should have the same batch size, ' \
+        f'but got {points.shape[0]} and {boxes.shape[0]}'
+    assert boxes.shape[2] == 7, \
+        'boxes dimension should be 7, ' \
+        f'but got unexpected shape {boxes.shape[2]}'
+    assert points.shape[2] == 3, \
+        'points dimension should be 3, ' \
+        f'but got unexpected shape {points.shape[2]}'
+    batch_size, num_points, _ = points.shape
+
+    box_idxs_of_pts = points.new_zeros((batch_size, num_points),
+                                       dtype=torch.int).fill_(-1)
+
+    # If the tensors 'points' or 'boxes' are manually put on a device
+    # which is not the current device, some temporary variables
+    # will be created on the current device in the cuda op,
+    # and the output will be incorrect.
+    # Therefore, we force the current device to be the same
+    # as the device of the tensors if it was not.
+    # Please refer to https://github.com/open-mmlab/mmdetection3d/issues/305
+    # for the incorrect output before the fix.
+    points_device = points.get_device()
+    assert points_device == boxes.get_device(), \
+        'Points and boxes should be put on the same device'
+    if torch.cuda.current_device() != points_device:
+        torch.cuda.set_device(points_device)
+
+    ext_module.points_in_boxes_part_forward(boxes.contiguous(),
+                                            points.contiguous(),
+                                            box_idxs_of_pts)
+
+    return box_idxs_of_pts
+
+
+def points_in_boxes_cpu(points, boxes):
+    """Find all boxes in which each point is (CPU). The CPU version of
+    :meth:`points_in_boxes_all`.
+
+    Args:
+        points (torch.Tensor): [B, M, 3], [x, y, z] in
+            LiDAR/DEPTH coordinate
+        boxes (torch.Tensor): [B, T, 7],
+            num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz],
+            (x, y, z) is the bottom center.
+
+    Returns:
+        box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0.
+    """
+    assert points.shape[0] == boxes.shape[0], \
+        'Points and boxes should have the same batch size, ' \
+        f'but got {points.shape[0]} and {boxes.shape[0]}'
+    assert boxes.shape[2] == 7, \
+        'boxes dimension should be 7, ' \
+        f'but got unexpected shape {boxes.shape[2]}'
+    assert points.shape[2] == 3, \
+        'points dimension should be 3, ' \
+        f'but got unexpected shape {points.shape[2]}'
+    batch_size, num_points, _ = points.shape
+    num_boxes = boxes.shape[1]
+
+    point_indices = points.new_zeros((batch_size, num_boxes, num_points),
+                                     dtype=torch.int)
+    for b in range(batch_size):
+        ext_module.points_in_boxes_cpu_forward(boxes[b].float().contiguous(),
+                                               points[b].float().contiguous(),
+                                               point_indices[b])
+    point_indices = point_indices.transpose(1, 2)
+
+    return point_indices
+
+
+def points_in_boxes_all(points, boxes):
+    """Find all boxes in which each point is (CUDA).
+
+    Args:
+        points (torch.Tensor): [B, M, 3], [x, y, z] in LiDAR/DEPTH coordinate
+        boxes (torch.Tensor): [B, T, 7],
+            num_valid_boxes <= T, [x, y, z, x_size, y_size, z_size, rz],
+            (x, y, z) is the bottom center.
+
+    Returns:
+        box_idxs_of_pts (torch.Tensor): (B, M, T), default background = 0.
+    """
+    assert boxes.shape[0] == points.shape[0], \
+        'Points and boxes should have the same batch size, ' \
+        f'but got {points.shape[0]} and {boxes.shape[0]}'
+    assert boxes.shape[2] == 7, \
+        'boxes dimension should be 7, ' \
+        f'but got unexpected shape {boxes.shape[2]}'
+    assert points.shape[2] == 3, \
+        'points dimension should be 3, ' \
+        f'but got unexpected shape {points.shape[2]}'
+    batch_size, num_points, _ = points.shape
+    num_boxes = boxes.shape[1]
+
+    box_idxs_of_pts = points.new_zeros((batch_size, num_points, num_boxes),
+                                       dtype=torch.int).fill_(0)
+
+    # Same device-handling reason as in points_in_boxes_part above
+    points_device = points.get_device()
+    assert points_device == boxes.get_device(), \
+        'Points and boxes should be put on the same device'
+    if torch.cuda.current_device() != points_device:
+        torch.cuda.set_device(points_device)
+
+    ext_module.points_in_boxes_all_forward(boxes.contiguous(),
+                                           points.contiguous(),
+                                           box_idxs_of_pts)
+
+    return box_idxs_of_pts
diff --git a/annotator/uniformer/mmcv/ops/points_sampler.py b/annotator/uniformer/mmcv/ops/points_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..a802a74fd6c3610d9ae178e6201f47423eca7ad1
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/points_sampler.py
@@ -0,0 +1,177 @@
+from typing import List
+
+import torch
+from torch import nn as nn
+
+from annotator.uniformer.mmcv.runner import force_fp32
+from .furthest_point_sample import (furthest_point_sample,
+                                    furthest_point_sample_with_dist)
+
+
+def calc_square_dist(point_feat_a, point_feat_b, norm=True):
+    """Calculating square distance between a and b.
+
+    Args:
+        point_feat_a (Tensor): (B, N, C) Feature vector of each point.
+        point_feat_b (Tensor): (B, M, C) Feature vector of each point.
+        norm (bool, optional): Whether to normalize the distance.
+            Default: True.
+
+    Returns:
+        Tensor: (B, N, M) Distances between each pair of points.
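+
+    Example:
+        >>> # Illustrative sketch; with norm=False the raw squared distances
+        >>> # are returned as a (B, N, M) matrix.
+        >>> a = torch.randn(2, 5, 3)
+        >>> b = torch.randn(2, 7, 3)
+        >>> calc_square_dist(a, b, norm=False).shape
+        torch.Size([2, 5, 7])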
+    """
+    num_channel = point_feat_a.shape[-1]
+    # [bs, n, 1]
+    a_square = torch.sum(point_feat_a.unsqueeze(dim=2).pow(2), dim=-1)
+    # [bs, 1, m]
+    b_square = torch.sum(point_feat_b.unsqueeze(dim=1).pow(2), dim=-1)
+
+    corr_matrix = torch.matmul(point_feat_a, point_feat_b.transpose(1, 2))
+
+    dist = a_square + b_square - 2 * corr_matrix
+    if norm:
+        dist = torch.sqrt(dist) / num_channel
+    return dist
+
+
+def get_sampler_cls(sampler_type):
+    """Get the type and mode of points sampler.
+
+    Args:
+        sampler_type (str): The type of points sampler.
+            The valid value are "D-FPS", "F-FPS", or "FS".
+
+    Returns:
+        class: Points sampler type.
+    """
+    sampler_mappings = {
+        'D-FPS': DFPSSampler,
+        'F-FPS': FFPSSampler,
+        'FS': FSSampler,
+    }
+    try:
+        return sampler_mappings[sampler_type]
+    except KeyError:
+        raise KeyError(
+            f'Supported `sampler_type` are {sampler_mappings.keys()}, '
+            f'but got {sampler_type}')
+
+
+class PointsSampler(nn.Module):
+    """Points sampling.
+
+    Args:
+        num_point (list[int]): Number of sample points.
+        fps_mod_list (list[str], optional): Type of FPS method, valid modes
+            are ['F-FPS', 'D-FPS', 'FS']. Default: ['D-FPS'].
+            F-FPS: using feature distances for FPS.
+            D-FPS: using Euclidean distances of points for FPS.
+            FS: using F-FPS and D-FPS simultaneously.
+        fps_sample_range_list (list[int], optional):
+            Range of points to apply FPS. Default: [-1].
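+
+    Example:
+        >>> # Illustrative sketch; D-FPS relies on the CUDA
+        >>> # ``furthest_point_sample`` op from the compiled extension.
+        >>> sampler = PointsSampler(num_point=[64], fps_mod_list=['D-FPS'])
+        >>> points_xyz = torch.rand(2, 1024, 3).cuda()
+        >>> indices = sampler(points_xyz, features=None)  # shape (2, 64)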
+    """
+
+    def __init__(self,
+                 num_point: List[int],
+                 fps_mod_list: List[str] = ['D-FPS'],
+                 fps_sample_range_list: List[int] = [-1]):
+        super().__init__()
+        # FPS is applied separately for each fps_mod in the list,
+        # so the length of num_point should equal the lengths of
+        # fps_mod_list and fps_sample_range_list.
+        assert len(num_point) == len(fps_mod_list) == len(
+            fps_sample_range_list)
+        self.num_point = num_point
+        self.fps_sample_range_list = fps_sample_range_list
+        self.samplers = nn.ModuleList()
+        for fps_mod in fps_mod_list:
+            self.samplers.append(get_sampler_cls(fps_mod)())
+        self.fp16_enabled = False
+
+    @force_fp32()
+    def forward(self, points_xyz, features):
+        """
+        Args:
+            points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
+            features (Tensor): (B, C, N) Descriptors of the features.
+
+        Returns:
+            Tensor: (B, sum(num_point)) Indices of the sampled points.
+        """
+        indices = []
+        last_fps_end_index = 0
+
+        for fps_sample_range, sampler, npoint in zip(
+                self.fps_sample_range_list, self.samplers, self.num_point):
+            assert fps_sample_range < points_xyz.shape[1]
+
+            if fps_sample_range == -1:
+                sample_points_xyz = points_xyz[:, last_fps_end_index:]
+                if features is not None:
+                    sample_features = features[:, :, last_fps_end_index:]
+                else:
+                    sample_features = None
+            else:
+                sample_points_xyz = \
+                    points_xyz[:, last_fps_end_index:fps_sample_range]
+                if features is not None:
+                    sample_features = features[:, :, last_fps_end_index:
+                                               fps_sample_range]
+                else:
+                    sample_features = None
+
+            fps_idx = sampler(sample_points_xyz.contiguous(), sample_features,
+                              npoint)
+
+            indices.append(fps_idx + last_fps_end_index)
+            last_fps_end_index += fps_sample_range
+        indices = torch.cat(indices, dim=1)
+
+        return indices
+
+
+class DFPSSampler(nn.Module):
+    """Using Euclidean distances of points for FPS."""
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, points, features, npoint):
+        """Sampling points with D-FPS."""
+        fps_idx = furthest_point_sample(points.contiguous(), npoint)
+        return fps_idx
+
+
+class FFPSSampler(nn.Module):
+    """Using feature distances for FPS."""
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, points, features, npoint):
+        """Sampling points with F-FPS."""
+        assert features is not None, \
+            'feature input to FFPS_Sampler should not be None'
+        features_for_fps = torch.cat([points, features.transpose(1, 2)], dim=2)
+        features_dist = calc_square_dist(
+            features_for_fps, features_for_fps, norm=False)
+        fps_idx = furthest_point_sample_with_dist(features_dist, npoint)
+        return fps_idx
+
+
+class FSSampler(nn.Module):
+    """Using F-FPS and D-FPS simultaneously."""
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, points, features, npoint):
+        """Sampling points with FS_Sampling."""
+        assert features is not None, \
+            'feature input to FS_Sampler should not be None'
+        ffps_sampler = FFPSSampler()
+        dfps_sampler = DFPSSampler()
+        fps_idx_ffps = ffps_sampler(points, features, npoint)
+        fps_idx_dfps = dfps_sampler(points, features, npoint)
+        fps_idx = torch.cat([fps_idx_ffps, fps_idx_dfps], dim=1)
+        return fps_idx
diff --git a/annotator/uniformer/mmcv/ops/psa_mask.py b/annotator/uniformer/mmcv/ops/psa_mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdf14e62b50e8d4dd6856c94333c703bcc4c9ab6
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/psa_mask.py
@@ -0,0 +1,92 @@
+# Modified from https://github.com/hszhao/semseg/blob/master/lib/psa
+from torch import nn
+from torch.autograd import Function
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext',
+                                 ['psamask_forward', 'psamask_backward'])
+
+
+class PSAMaskFunction(Function):
+
+    @staticmethod
+    def symbolic(g, input, psa_type, mask_size):
+        return g.op(
+            'mmcv::MMCVPSAMask',
+            input,
+            psa_type_i=psa_type,
+            mask_size_i=mask_size)
+
+    @staticmethod
+    def forward(ctx, input, psa_type, mask_size):
+        ctx.psa_type = psa_type
+        ctx.mask_size = _pair(mask_size)
+        ctx.save_for_backward(input)
+
+        h_mask, w_mask = ctx.mask_size
+        batch_size, channels, h_feature, w_feature = input.size()
+        assert channels == h_mask * w_mask
+        output = input.new_zeros(
+            (batch_size, h_feature * w_feature, h_feature, w_feature))
+
+        ext_module.psamask_forward(
+            input,
+            output,
+            psa_type=psa_type,
+            num_=batch_size,
+            h_feature=h_feature,
+            w_feature=w_feature,
+            h_mask=h_mask,
+            w_mask=w_mask,
+            half_h_mask=(h_mask - 1) // 2,
+            half_w_mask=(w_mask - 1) // 2)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input = ctx.saved_tensors[0]
+        psa_type = ctx.psa_type
+        h_mask, w_mask = ctx.mask_size
+        batch_size, channels, h_feature, w_feature = input.size()
+        grad_input = grad_output.new_zeros(
+            (batch_size, channels, h_feature, w_feature))
+        ext_module.psamask_backward(
+            grad_output,
+            grad_input,
+            psa_type=psa_type,
+            num_=batch_size,
+            h_feature=h_feature,
+            w_feature=w_feature,
+            h_mask=h_mask,
+            w_mask=w_mask,
+            half_h_mask=(h_mask - 1) // 2,
+            half_w_mask=(w_mask - 1) // 2)
+        return grad_input, None, None, None
+
+
+psa_mask = PSAMaskFunction.apply
+
+
+class PSAMask(nn.Module):
+
+    def __init__(self, psa_type, mask_size=None):
+        super(PSAMask, self).__init__()
+        assert psa_type in ['collect', 'distribute']
+        if psa_type == 'collect':
+            psa_type_enum = 0
+        else:
+            psa_type_enum = 1
+        self.psa_type_enum = psa_type_enum
+        self.mask_size = mask_size
+        self.psa_type = psa_type
+
+    def forward(self, input):
+        return psa_mask(input, self.psa_type_enum, self.mask_size)
+
+    def __repr__(self):
+        s = self.__class__.__name__
+        s += f'(psa_type={self.psa_type}, '
+        s += f'mask_size={self.mask_size})'
+        return s
diff --git a/annotator/uniformer/mmcv/ops/roi_align.py b/annotator/uniformer/mmcv/ops/roi_align.py
new file mode 100644
index 0000000000000000000000000000000000000000..0755aefc66e67233ceae0f4b77948301c443e9fb
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/roi_align.py
@@ -0,0 +1,223 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from ..utils import deprecated_api_warning, ext_loader
+
+ext_module = ext_loader.load_ext('_ext',
+                                 ['roi_align_forward', 'roi_align_backward'])
+
+
+class RoIAlignFunction(Function):
+
+    @staticmethod
+    def symbolic(g, input, rois, output_size, spatial_scale, sampling_ratio,
+                 pool_mode, aligned):
+        from ..onnx import is_custom_op_loaded
+        has_custom_op = is_custom_op_loaded()
+        if has_custom_op:
+            return g.op(
+                'mmcv::MMCVRoiAlign',
+                input,
+                rois,
+                output_height_i=output_size[0],
+                output_width_i=output_size[1],
+                spatial_scale_f=spatial_scale,
+                sampling_ratio_i=sampling_ratio,
+                mode_s=pool_mode,
+                aligned_i=aligned)
+        else:
+            from torch.onnx.symbolic_opset9 import sub, squeeze
+            from torch.onnx.symbolic_helper import _slice_helper
+            from torch.onnx import TensorProtoDataType
+            # batch_indices = rois[:, 0].long()
+            batch_indices = _slice_helper(
+                g, rois, axes=[1], starts=[0], ends=[1])
+            batch_indices = squeeze(g, batch_indices, 1)
+            batch_indices = g.op(
+                'Cast', batch_indices, to_i=TensorProtoDataType.INT64)
+            # rois = rois[:, 1:]
+            rois = _slice_helper(g, rois, axes=[1], starts=[1], ends=[5])
+            if aligned:
+                # rois -= 0.5/spatial_scale
+                aligned_offset = g.op(
+                    'Constant',
+                    value_t=torch.tensor([0.5 / spatial_scale],
+                                         dtype=torch.float32))
+                rois = sub(g, rois, aligned_offset)
+            # roi align
+            return g.op(
+                'RoiAlign',
+                input,
+                rois,
+                batch_indices,
+                output_height_i=output_size[0],
+                output_width_i=output_size[1],
+                spatial_scale_f=spatial_scale,
+                sampling_ratio_i=max(0, sampling_ratio),
+                mode_s=pool_mode)
+
+    @staticmethod
+    def forward(ctx,
+                input,
+                rois,
+                output_size,
+                spatial_scale=1.0,
+                sampling_ratio=0,
+                pool_mode='avg',
+                aligned=True):
+        ctx.output_size = _pair(output_size)
+        ctx.spatial_scale = spatial_scale
+        ctx.sampling_ratio = sampling_ratio
+        assert pool_mode in ('max', 'avg')
+        ctx.pool_mode = 0 if pool_mode == 'max' else 1
+        ctx.aligned = aligned
+        ctx.input_shape = input.size()
+
+        assert rois.size(1) == 5, 'RoI must be (idx, x1, y1, x2, y2)!'
+
+        output_shape = (rois.size(0), input.size(1), ctx.output_size[0],
+                        ctx.output_size[1])
+        output = input.new_zeros(output_shape)
+        if ctx.pool_mode == 0:
+            argmax_y = input.new_zeros(output_shape)
+            argmax_x = input.new_zeros(output_shape)
+        else:
+            argmax_y = input.new_zeros(0)
+            argmax_x = input.new_zeros(0)
+
+        ext_module.roi_align_forward(
+            input,
+            rois,
+            output,
+            argmax_y,
+            argmax_x,
+            aligned_height=ctx.output_size[0],
+            aligned_width=ctx.output_size[1],
+            spatial_scale=ctx.spatial_scale,
+            sampling_ratio=ctx.sampling_ratio,
+            pool_mode=ctx.pool_mode,
+            aligned=ctx.aligned)
+
+        ctx.save_for_backward(rois, argmax_y, argmax_x)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        rois, argmax_y, argmax_x = ctx.saved_tensors
+        grad_input = grad_output.new_zeros(ctx.input_shape)
+        # complex head architecture may cause grad_output uncontiguous.
+        grad_output = grad_output.contiguous()
+        ext_module.roi_align_backward(
+            grad_output,
+            rois,
+            argmax_y,
+            argmax_x,
+            grad_input,
+            aligned_height=ctx.output_size[0],
+            aligned_width=ctx.output_size[1],
+            spatial_scale=ctx.spatial_scale,
+            sampling_ratio=ctx.sampling_ratio,
+            pool_mode=ctx.pool_mode,
+            aligned=ctx.aligned)
+        return grad_input, None, None, None, None, None, None
+
+
+roi_align = RoIAlignFunction.apply
+
+
+class RoIAlign(nn.Module):
+    """RoI align pooling layer.
+
+    Args:
+        output_size (tuple): h, w
+        spatial_scale (float): scale the input boxes by this number
+        sampling_ratio (int): number of input samples to take for each
+            output sample. 0 to take samples densely for current models.
+        pool_mode (str, 'avg' or 'max'): pooling mode in each bin.
+        aligned (bool): if False, use the legacy implementation in
+            MMDetection. If True, align the results more perfectly.
+        use_torchvision (bool): whether to use roi_align from torchvision.
+
+    Note:
+        The implementation of RoIAlign when aligned=True is modified from
+        https://github.com/facebookresearch/detectron2/
+
+        The meaning of aligned=True:
+
+        Given a continuous coordinate c, its two neighboring pixel
+        indices (in our pixel model) are computed by floor(c - 0.5) and
+        ceil(c - 0.5). For example, c=1.3 has pixel neighbors with discrete
+        indices [0] and [1] (which are sampled from the underlying signal
+        at continuous coordinates 0.5 and 1.5). But the original roi_align
+        (aligned=False) does not subtract the 0.5 when computing
+        neighboring pixel indices and therefore it uses pixels with a
+        slightly incorrect alignment (relative to our pixel model) when
+        performing bilinear interpolation.
+
+        With `aligned=True`, we first appropriately scale the ROI and then
+        shift it by -0.5 prior to calling roi_align. This produces the
+        correct neighbors.
+
+        This difference does not affect the model's performance if
+        RoIAlign is used together with conv layers.
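+
+    Example:
+        >>> # Illustrative sketch using the torchvision backend (assumes
+        >>> # torchvision is installed); rois are (batch_idx, x1, y1, x2, y2).
+        >>> feat = torch.randn(1, 8, 16, 16)
+        >>> rois = torch.tensor([[0., 2., 2., 10., 10.]])
+        >>> layer = RoIAlign(output_size=7, use_torchvision=True)
+        >>> layer(feat, rois).shape
+        torch.Size([1, 8, 7, 7])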
+    """
+
+    @deprecated_api_warning(
+        {
+            'out_size': 'output_size',
+            'sample_num': 'sampling_ratio'
+        },
+        cls_name='RoIAlign')
+    def __init__(self,
+                 output_size,
+                 spatial_scale=1.0,
+                 sampling_ratio=0,
+                 pool_mode='avg',
+                 aligned=True,
+                 use_torchvision=False):
+        super(RoIAlign, self).__init__()
+
+        self.output_size = _pair(output_size)
+        self.spatial_scale = float(spatial_scale)
+        self.sampling_ratio = int(sampling_ratio)
+        self.pool_mode = pool_mode
+        self.aligned = aligned
+        self.use_torchvision = use_torchvision
+
+    def forward(self, input, rois):
+        """
+        Args:
+            input: NCHW images
+            rois: Bx5 boxes. First column is the index into N.\
+                The other 4 columns are xyxy.
+        """
+        if self.use_torchvision:
+            from torchvision.ops import roi_align as tv_roi_align
+            if 'aligned' in tv_roi_align.__code__.co_varnames:
+                return tv_roi_align(input, rois, self.output_size,
+                                    self.spatial_scale, self.sampling_ratio,
+                                    self.aligned)
+            else:
+                if self.aligned:
+                    rois -= rois.new_tensor([0.] +
+                                            [0.5 / self.spatial_scale] * 4)
+                return tv_roi_align(input, rois, self.output_size,
+                                    self.spatial_scale, self.sampling_ratio)
+        else:
+            return roi_align(input, rois, self.output_size, self.spatial_scale,
+                             self.sampling_ratio, self.pool_mode, self.aligned)
+
+    def __repr__(self):
+        s = self.__class__.__name__
+        s += f'(output_size={self.output_size}, '
+        s += f'spatial_scale={self.spatial_scale}, '
+        s += f'sampling_ratio={self.sampling_ratio}, '
+        s += f'pool_mode={self.pool_mode}, '
+        s += f'aligned={self.aligned}, '
+        s += f'use_torchvision={self.use_torchvision})'
+        return s
diff --git a/annotator/uniformer/mmcv/ops/roi_align_rotated.py b/annotator/uniformer/mmcv/ops/roi_align_rotated.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ce4961a3555d4da8bc3e32f1f7d5ad50036587d
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/roi_align_rotated.py
@@ -0,0 +1,177 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+    '_ext', ['roi_align_rotated_forward', 'roi_align_rotated_backward'])
+
+
+class RoIAlignRotatedFunction(Function):
+
+    @staticmethod
+    def symbolic(g, features, rois, out_size, spatial_scale, sample_num,
+                 aligned, clockwise):
+        if isinstance(out_size, int):
+            out_h = out_size
+            out_w = out_size
+        elif isinstance(out_size, tuple):
+            assert len(out_size) == 2
+            assert isinstance(out_size[0], int)
+            assert isinstance(out_size[1], int)
+            out_h, out_w = out_size
+        else:
+            raise TypeError(
+                '"out_size" must be an integer or tuple of integers')
+        return g.op(
+            'mmcv::MMCVRoIAlignRotated',
+            features,
+            rois,
+            output_height_i=out_h,
+            output_width_i=out_w,
+            spatial_scale_f=spatial_scale,
+            sampling_ratio_i=sample_num,
+            aligned_i=aligned,
+            clockwise_i=clockwise)
+
+    @staticmethod
+    def forward(ctx,
+                features,
+                rois,
+                out_size,
+                spatial_scale,
+                sample_num=0,
+                aligned=True,
+                clockwise=False):
+        if isinstance(out_size, int):
+            out_h = out_size
+            out_w = out_size
+        elif isinstance(out_size, tuple):
+            assert len(out_size) == 2
+            assert isinstance(out_size[0], int)
+            assert isinstance(out_size[1], int)
+            out_h, out_w = out_size
+        else:
+            raise TypeError(
+                '"out_size" must be an integer or tuple of integers')
+        ctx.spatial_scale = spatial_scale
+        ctx.sample_num = sample_num
+        ctx.aligned = aligned
+        ctx.clockwise = clockwise
+        ctx.save_for_backward(rois)
+        ctx.feature_size = features.size()
+
+        batch_size, num_channels, data_height, data_width = features.size()
+        num_rois = rois.size(0)
+
+        output = features.new_zeros(num_rois, num_channels, out_h, out_w)
+        ext_module.roi_align_rotated_forward(
+            features,
+            rois,
+            output,
+            pooled_height=out_h,
+            pooled_width=out_w,
+            spatial_scale=spatial_scale,
+            sample_num=sample_num,
+            aligned=aligned,
+            clockwise=clockwise)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        feature_size = ctx.feature_size
+        spatial_scale = ctx.spatial_scale
+        aligned = ctx.aligned
+        clockwise = ctx.clockwise
+        sample_num = ctx.sample_num
+        rois = ctx.saved_tensors[0]
+        assert feature_size is not None
+        batch_size, num_channels, data_height, data_width = feature_size
+
+        out_w = grad_output.size(3)
+        out_h = grad_output.size(2)
+
+        grad_input = grad_rois = None
+
+        if ctx.needs_input_grad[0]:
+            grad_input = rois.new_zeros(batch_size, num_channels, data_height,
+                                        data_width)
+            ext_module.roi_align_rotated_backward(
+                grad_output.contiguous(),
+                rois,
+                grad_input,
+                pooled_height=out_h,
+                pooled_width=out_w,
+                spatial_scale=spatial_scale,
+                sample_num=sample_num,
+                aligned=aligned,
+                clockwise=clockwise)
+        return grad_input, grad_rois, None, None, None, None, None
+
+
+roi_align_rotated = RoIAlignRotatedFunction.apply
+
+
+class RoIAlignRotated(nn.Module):
+    """RoI align pooling layer for rotated proposals.
+
+    It accepts a feature map of shape (N, C, H, W) and rois with shape
+    (n, 6) with each roi decoded as (batch_index, center_x, center_y,
+    w, h, angle). The angle is in radians.
+
+    Args:
+        out_size (tuple): h, w
+        spatial_scale (float): scale the input boxes by this number
+        sample_num (int): number of input samples to take for each
+            output sample. 0 to take samples densely for current models.
+        aligned (bool): if False, use the legacy implementation in
+            MMDetection. If True, align the results more perfectly.
+            Default: True.
+        clockwise (bool): If True, the angle in each proposal follows a
+            clockwise fashion in image space, otherwise, the angle is
+            counterclockwise. Default: False.
+
+    Note:
+        The implementation of RoIAlign when aligned=True is modified from
+        https://github.com/facebookresearch/detectron2/
+
+        The meaning of aligned=True:
+
+        Given a continuous coordinate c, its two neighboring pixel
+        indices (in our pixel model) are computed by floor(c - 0.5) and
+        ceil(c - 0.5). For example, c=1.3 has pixel neighbors with discrete
+        indices [0] and [1] (which are sampled from the underlying signal
+        at continuous coordinates 0.5 and 1.5). But the original roi_align
+        (aligned=False) does not subtract the 0.5 when computing
+        neighboring pixel indices and therefore it uses pixels with a
+        slightly incorrect alignment (relative to our pixel model) when
+        performing bilinear interpolation.
+
+        With `aligned=True`, we first appropriately scale the ROI and then
+        shift it by -0.5 prior to calling roi_align. This produces the
+        correct neighbors.
+
+        This difference does not affect the model's performance if
+        RoIAlign is used together with conv layers.
+    """
+
+    def __init__(self,
+                 out_size,
+                 spatial_scale,
+                 sample_num=0,
+                 aligned=True,
+                 clockwise=False):
+        super(RoIAlignRotated, self).__init__()
+
+        self.out_size = out_size
+        self.spatial_scale = float(spatial_scale)
+        self.sample_num = int(sample_num)
+        self.aligned = aligned
+        self.clockwise = clockwise
+
+    def forward(self, features, rois):
+        return RoIAlignRotatedFunction.apply(features, rois, self.out_size,
+                                             self.spatial_scale,
+                                             self.sample_num, self.aligned,
+                                             self.clockwise)
diff --git a/annotator/uniformer/mmcv/ops/roi_pool.py b/annotator/uniformer/mmcv/ops/roi_pool.py
new file mode 100644
index 0000000000000000000000000000000000000000..d339d8f2941eabc1cbe181a9c6c5ab5ff4ff4e5f
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/roi_pool.py
@@ -0,0 +1,86 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext',
+                                 ['roi_pool_forward', 'roi_pool_backward'])
+
+
+class RoIPoolFunction(Function):
+
+    @staticmethod
+    def symbolic(g, input, rois, output_size, spatial_scale):
+        return g.op(
+            'MaxRoiPool',
+            input,
+            rois,
+            pooled_shape_i=output_size,
+            spatial_scale_f=spatial_scale)
+
+    @staticmethod
+    def forward(ctx, input, rois, output_size, spatial_scale=1.0):
+        ctx.output_size = _pair(output_size)
+        ctx.spatial_scale = spatial_scale
+        ctx.input_shape = input.size()
+
+        assert rois.size(1) == 5, 'RoI must be (idx, x1, y1, x2, y2)!'
+
+        output_shape = (rois.size(0), input.size(1), ctx.output_size[0],
+                        ctx.output_size[1])
+        output = input.new_zeros(output_shape)
+        argmax = input.new_zeros(output_shape, dtype=torch.int)
+
+        ext_module.roi_pool_forward(
+            input,
+            rois,
+            output,
+            argmax,
+            pooled_height=ctx.output_size[0],
+            pooled_width=ctx.output_size[1],
+            spatial_scale=ctx.spatial_scale)
+
+        ctx.save_for_backward(rois, argmax)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        rois, argmax = ctx.saved_tensors
+        grad_input = grad_output.new_zeros(ctx.input_shape)
+
+        ext_module.roi_pool_backward(
+            grad_output,
+            rois,
+            argmax,
+            grad_input,
+            pooled_height=ctx.output_size[0],
+            pooled_width=ctx.output_size[1],
+            spatial_scale=ctx.spatial_scale)
+
+        return grad_input, None, None, None
+
+
+roi_pool = RoIPoolFunction.apply
+
+
+class RoIPool(nn.Module):
+
+    def __init__(self, output_size, spatial_scale=1.0):
+        super(RoIPool, self).__init__()
+
+        self.output_size = _pair(output_size)
+        self.spatial_scale = float(spatial_scale)
+
+    def forward(self, input, rois):
+        return roi_pool(input, rois, self.output_size, self.spatial_scale)
+
+    def __repr__(self):
+        s = self.__class__.__name__
+        s += f'(output_size={self.output_size}, '
+        s += f'spatial_scale={self.spatial_scale})'
+        return s
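+
+
+# Minimal usage sketch (illustrative only): it assumes the compiled `_ext`
+# CUDA extension and a CUDA device are available; shapes and RoI values
+# below are arbitrary.
+if __name__ == '__main__':
+    if torch.cuda.is_available():
+        feats = torch.rand(1, 16, 32, 32, device='cuda')
+        # each RoI is (batch_idx, x1, y1, x2, y2)
+        rois = torch.tensor([[0., 0., 0., 15., 15.]], device='cuda')
+        pool = RoIPool(output_size=7, spatial_scale=1.0)
+        out = pool(feats, rois)  # -> (1, 16, 7, 7)
+        print(pool, out.shape)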
diff --git a/annotator/uniformer/mmcv/ops/roiaware_pool3d.py b/annotator/uniformer/mmcv/ops/roiaware_pool3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..291b0e5a9b692492c7d7e495ea639c46042e2f18
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/roiaware_pool3d.py
@@ -0,0 +1,114 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn as nn
+from torch.autograd import Function
+
+import annotator.uniformer.mmcv as mmcv
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+    '_ext', ['roiaware_pool3d_forward', 'roiaware_pool3d_backward'])
+
+
+class RoIAwarePool3d(nn.Module):
+    """Encode the geometry-specific features of each 3D proposal.
+
+    Please refer to `PartA2 <https://arxiv.org/pdf/1907.03670.pdf>`_ for more
+    details.
+
+    Args:
+        out_size (int or tuple): The size of output features. n or
+            [n1, n2, n3].
+        max_pts_per_voxel (int, optional): The maximum number of points per
+            voxel. Default: 128.
+        mode (str, optional): Pooling method of RoIAware, 'max' or 'avg'.
+            Default: 'max'.
+    """
+
+    def __init__(self, out_size, max_pts_per_voxel=128, mode='max'):
+        super().__init__()
+
+        self.out_size = out_size
+        self.max_pts_per_voxel = max_pts_per_voxel
+        assert mode in ['max', 'avg']
+        pool_mapping = {'max': 0, 'avg': 1}
+        self.mode = pool_mapping[mode]
+
+    def forward(self, rois, pts, pts_feature):
+        """
+        Args:
+            rois (torch.Tensor): [N, 7], in LiDAR coordinates;
+                (x, y, z) is the bottom center of the RoIs.
+            pts (torch.Tensor): [npoints, 3], coordinates of input points.
+            pts_feature (torch.Tensor): [npoints, C], features of input points.
+
+        Returns:
+            pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C]
+        """
+
+        return RoIAwarePool3dFunction.apply(rois, pts, pts_feature,
+                                            self.out_size,
+                                            self.max_pts_per_voxel, self.mode)
+
+
+class RoIAwarePool3dFunction(Function):
+
+    @staticmethod
+    def forward(ctx, rois, pts, pts_feature, out_size, max_pts_per_voxel,
+                mode):
+        """
+        Args:
+            rois (torch.Tensor): [N, 7], in LiDAR coordinates;
+                (x, y, z) is the bottom center of the RoIs.
+            pts (torch.Tensor): [npoints, 3], coordinates of input points.
+            pts_feature (torch.Tensor): [npoints, C], features of input points.
+            out_size (int or tuple): The size of output features. n or
+                [n1, n2, n3].
+            max_pts_per_voxel (int): The maximum number of points per voxel.
+                Default: 128.
+            mode (int): Pooling method of RoIAware, 0 (max pool) or 1 (average
+                pool).
+
+        Returns:
+            pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C], output
+                pooled features.
+        """
+
+        if isinstance(out_size, int):
+            out_x = out_y = out_z = out_size
+        else:
+            assert len(out_size) == 3
+            assert mmcv.is_tuple_of(out_size, int)
+            out_x, out_y, out_z = out_size
+
+        num_rois = rois.shape[0]
+        num_channels = pts_feature.shape[-1]
+        num_pts = pts.shape[0]
+
+        pooled_features = pts_feature.new_zeros(
+            (num_rois, out_x, out_y, out_z, num_channels))
+        argmax = pts_feature.new_zeros(
+            (num_rois, out_x, out_y, out_z, num_channels), dtype=torch.int)
+        pts_idx_of_voxels = pts_feature.new_zeros(
+            (num_rois, out_x, out_y, out_z, max_pts_per_voxel),
+            dtype=torch.int)
+
+        ext_module.roiaware_pool3d_forward(rois, pts, pts_feature, argmax,
+                                           pts_idx_of_voxels, pooled_features,
+                                           mode)
+
+        ctx.roiaware_pool3d_for_backward = (pts_idx_of_voxels, argmax, mode,
+                                            num_pts, num_channels)
+        return pooled_features
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        ret = ctx.roiaware_pool3d_for_backward
+        pts_idx_of_voxels, argmax, mode, num_pts, num_channels = ret
+
+        grad_in = grad_out.new_zeros((num_pts, num_channels))
+        ext_module.roiaware_pool3d_backward(pts_idx_of_voxels, argmax,
+                                            grad_out.contiguous(), grad_in,
+                                            mode)
+
+        return None, None, grad_in, None, None, None
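+
+
+# Minimal usage sketch (illustrative only): it assumes the compiled `_ext`
+# CUDA extension and a CUDA device are available; shapes and box values are
+# arbitrary.
+if __name__ == '__main__':
+    if torch.cuda.is_available():
+        # rois: (N, 7) boxes in LiDAR coordinates, (x, y, z, dx, dy, dz, ry)
+        rois = torch.tensor([[0., 0., 0., 2., 2., 2., 0.]], device='cuda')
+        pts = torch.rand(128, 3, device='cuda') * 2 - 1
+        pts_feature = torch.rand(128, 16, device='cuda')
+        pool = RoIAwarePool3d(out_size=4, max_pts_per_voxel=64, mode='max')
+        pooled = pool(rois, pts, pts_feature)  # -> (1, 4, 4, 4, 16)
+        print(pooled.shape)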
diff --git a/annotator/uniformer/mmcv/ops/roipoint_pool3d.py b/annotator/uniformer/mmcv/ops/roipoint_pool3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a21412c0728431c04b84245bc2e3109eea9aefc
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/roipoint_pool3d.py
@@ -0,0 +1,77 @@
+from torch import nn as nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['roipoint_pool3d_forward'])
+
+
+class RoIPointPool3d(nn.Module):
+    """Encode the geometry-specific features of each 3D proposal.
+
+    Please refer to `Paper of PartA2 <https://arxiv.org/pdf/1907.03670.pdf>`_
+    for more details.
+
+    Args:
+        num_sampled_points (int, optional): Number of samples in each roi.
+            Default: 512.
+    """
+
+    def __init__(self, num_sampled_points=512):
+        super().__init__()
+        self.num_sampled_points = num_sampled_points
+
+    def forward(self, points, point_features, boxes3d):
+        """
+        Args:
+            points (torch.Tensor): Input points whose shape is (B, N, C).
+            point_features (torch.Tensor): Features of input points whose shape
+                is (B, N, C).
+            boxes3d (torch.Tensor): Input bounding boxes whose shape is
+                (B, M, 7).
+
+        Returns:
+            pooled_features (torch.Tensor): The output pooled features whose
+                shape is (B, M, 512, 3 + C).
+            pooled_empty_flag (torch.Tensor): Empty flag whose shape is (B, M).
+        """
+        return RoIPointPool3dFunction.apply(points, point_features, boxes3d,
+                                            self.num_sampled_points)
+
+
+class RoIPointPool3dFunction(Function):
+
+    @staticmethod
+    def forward(ctx, points, point_features, boxes3d, num_sampled_points=512):
+        """
+        Args:
+            points (torch.Tensor): Input points whose shape is (B, N, C).
+            point_features (torch.Tensor): Features of input points whose shape
+                is (B, N, C).
+            boxes3d (torch.Tensor): Input bounding boxes whose shape is
+                (B, M, 7).
+            num_sampled_points (int, optional): Number of sampled points.
+                Default: 512.
+
+        Returns:
+            pooled_features (torch.Tensor): The output pooled features whose
+                shape is (B, M, 512, 3 + C).
+            pooled_empty_flag (torch.Tensor): Empty flag whose shape is (B, M).
+        """
+        assert len(points.shape) == 3 and points.shape[2] == 3
+        batch_size, boxes_num, feature_len = points.shape[0], boxes3d.shape[
+            1], point_features.shape[2]
+        pooled_boxes3d = boxes3d.view(batch_size, -1, 7)
+        pooled_features = point_features.new_zeros(
+            (batch_size, boxes_num, num_sampled_points, 3 + feature_len))
+        pooled_empty_flag = point_features.new_zeros(
+            (batch_size, boxes_num)).int()
+
+        ext_module.roipoint_pool3d_forward(points.contiguous(),
+                                           pooled_boxes3d.contiguous(),
+                                           point_features.contiguous(),
+                                           pooled_features, pooled_empty_flag)
+
+        return pooled_features, pooled_empty_flag
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        raise NotImplementedError
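+
+
+# Minimal usage sketch (illustrative only): it assumes the compiled `_ext`
+# CUDA extension and a CUDA device are available; shapes and box values are
+# arbitrary.
+if __name__ == '__main__':
+    import torch
+    if torch.cuda.is_available():
+        points = torch.rand(1, 256, 3, device='cuda')
+        point_features = torch.rand(1, 256, 4, device='cuda')
+        # boxes3d: (B, M, 7) boxes, (x, y, z, dx, dy, dz, ry)
+        boxes3d = torch.tensor([[[0., 0., 0., 1., 1., 1., 0.]]],
+                               device='cuda')
+        pool = RoIPointPool3d(num_sampled_points=32)
+        pooled, empty_flag = pool(points, point_features, boxes3d)
+        print(pooled.shape, empty_flag.shape)  # (1, 1, 32, 7), (1, 1)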
diff --git a/annotator/uniformer/mmcv/ops/saconv.py b/annotator/uniformer/mmcv/ops/saconv.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4ee3978e097fca422805db4e31ae481006d7971
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/saconv.py
@@ -0,0 +1,145 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from annotator.uniformer.mmcv.cnn import CONV_LAYERS, ConvAWS2d, constant_init
+from annotator.uniformer.mmcv.ops.deform_conv import deform_conv2d
+from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version
+
+
+@CONV_LAYERS.register_module(name='SAC')
+class SAConv2d(ConvAWS2d):
+    """SAC (Switchable Atrous Convolution)
+
+    This is an implementation of SAC in DetectoRS
+    (https://arxiv.org/pdf/2006.02334.pdf).
+
+    Args:
+        in_channels (int): Number of channels in the input image
+        out_channels (int): Number of channels produced by the convolution
+        kernel_size (int or tuple): Size of the convolving kernel
+        stride (int or tuple, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): Zero-padding added to both sides of
+            the input. Default: 0
+        padding_mode (string, optional): ``'zeros'``, ``'reflect'``,
+            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
+        dilation (int or tuple, optional): Spacing between kernel elements.
+            Default: 1
+        groups (int, optional): Number of blocked connections from input
+            channels to output channels. Default: 1
+        bias (bool, optional): If ``True``, adds a learnable bias to the
+            output. Default: ``True``
+        use_deform (bool, optional): If ``True``, replace convolution with
+            deformable convolution. Default: ``False``.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 bias=True,
+                 use_deform=False):
+        super().__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias)
+        self.use_deform = use_deform
+        self.switch = nn.Conv2d(
+            self.in_channels, 1, kernel_size=1, stride=stride, bias=True)
+        self.weight_diff = nn.Parameter(torch.Tensor(self.weight.size()))
+        self.pre_context = nn.Conv2d(
+            self.in_channels, self.in_channels, kernel_size=1, bias=True)
+        self.post_context = nn.Conv2d(
+            self.out_channels, self.out_channels, kernel_size=1, bias=True)
+        if self.use_deform:
+            self.offset_s = nn.Conv2d(
+                self.in_channels,
+                18,
+                kernel_size=3,
+                padding=1,
+                stride=stride,
+                bias=True)
+            self.offset_l = nn.Conv2d(
+                self.in_channels,
+                18,
+                kernel_size=3,
+                padding=1,
+                stride=stride,
+                bias=True)
+        self.init_weights()
+
+    def init_weights(self):
+        constant_init(self.switch, 0, bias=1)
+        self.weight_diff.data.zero_()
+        constant_init(self.pre_context, 0)
+        constant_init(self.post_context, 0)
+        if self.use_deform:
+            constant_init(self.offset_s, 0)
+            constant_init(self.offset_l, 0)
+
+    def forward(self, x):
+        # pre-context
+        avg_x = F.adaptive_avg_pool2d(x, output_size=1)
+        avg_x = self.pre_context(avg_x)
+        avg_x = avg_x.expand_as(x)
+        x = x + avg_x
+        # switch
+        avg_x = F.pad(x, pad=(2, 2, 2, 2), mode='reflect')
+        avg_x = F.avg_pool2d(avg_x, kernel_size=5, stride=1, padding=0)
+        switch = self.switch(avg_x)
+        # sac
+        weight = self._get_weight(self.weight)
+        zero_bias = torch.zeros(
+            self.out_channels, device=weight.device, dtype=weight.dtype)
+
+        if self.use_deform:
+            offset = self.offset_s(avg_x)
+            out_s = deform_conv2d(x, offset, weight, self.stride, self.padding,
+                                  self.dilation, self.groups, 1)
+        else:
+            if (TORCH_VERSION == 'parrots'
+                    or digit_version(TORCH_VERSION) < digit_version('1.5.0')):
+                out_s = super().conv2d_forward(x, weight)
+            elif digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
+                # bias is a required argument of _conv_forward in torch 1.8.0
+                out_s = super()._conv_forward(x, weight, zero_bias)
+            else:
+                out_s = super()._conv_forward(x, weight)
+        ori_p = self.padding
+        ori_d = self.dilation
+        self.padding = tuple(3 * p for p in self.padding)
+        self.dilation = tuple(3 * d for d in self.dilation)
+        weight = weight + self.weight_diff
+        if self.use_deform:
+            offset = self.offset_l(avg_x)
+            out_l = deform_conv2d(x, offset, weight, self.stride, self.padding,
+                                  self.dilation, self.groups, 1)
+        else:
+            if (TORCH_VERSION == 'parrots'
+                    or digit_version(TORCH_VERSION) < digit_version('1.5.0')):
+                out_l = super().conv2d_forward(x, weight)
+            elif digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
+                # bias is a required argument of _conv_forward in torch 1.8.0
+                out_l = super()._conv_forward(x, weight, zero_bias)
+            else:
+                out_l = super()._conv_forward(x, weight)
+
+        out = switch * out_s + (1 - switch) * out_l
+        self.padding = ori_p
+        self.dilation = ori_d
+        # post-context
+        avg_x = F.adaptive_avg_pool2d(out, output_size=1)
+        avg_x = self.post_context(avg_x)
+        avg_x = avg_x.expand_as(out)
+        out = out + avg_x
+        return out
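+
+
+# Minimal usage sketch (illustrative only): plain SAC without deformable
+# convolution, so no CUDA extension is required; it assumes a PyTorch
+# version providing ``_conv_forward`` (>= 1.5).
+if __name__ == '__main__':
+    conv = SAConv2d(3, 16, kernel_size=3, padding=1)
+    x = torch.rand(1, 3, 32, 32)
+    out = conv(x)
+    print(out.shape)  # -> (1, 16, 32, 32)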
diff --git a/annotator/uniformer/mmcv/ops/scatter_points.py b/annotator/uniformer/mmcv/ops/scatter_points.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b8aa4169e9f6ca4a6f845ce17d6d1e4db416bb8
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/scatter_points.py
@@ -0,0 +1,135 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+    '_ext',
+    ['dynamic_point_to_voxel_forward', 'dynamic_point_to_voxel_backward'])
+
+
+class _DynamicScatter(Function):
+
+    @staticmethod
+    def forward(ctx, feats, coors, reduce_type='max'):
+        """convert kitti points(N, >=3) to voxels.
+
+        Args:
+            feats (torch.Tensor): [N, C]. Points features to be reduced
+                into voxels.
+            coors (torch.Tensor): [N, ndim]. Corresponding voxel coordinates
+                (specifically multi-dim voxel index) of each point.
+            reduce_type (str, optional): Reduce op. Supports 'max', 'sum' and
+                'mean'. Default: 'max'.
+
+        Returns:
+            voxel_feats (torch.Tensor): [M, C]. Reduced features; input
+                features that share the same voxel coordinates are reduced to
+                one row.
+            voxel_coors (torch.Tensor): [M, ndim]. Voxel coordinates.
+        """
+        results = ext_module.dynamic_point_to_voxel_forward(
+            feats, coors, reduce_type)
+        (voxel_feats, voxel_coors, point2voxel_map,
+         voxel_points_count) = results
+        ctx.reduce_type = reduce_type
+        ctx.save_for_backward(feats, voxel_feats, point2voxel_map,
+                              voxel_points_count)
+        ctx.mark_non_differentiable(voxel_coors)
+        return voxel_feats, voxel_coors
+
+    @staticmethod
+    def backward(ctx, grad_voxel_feats, grad_voxel_coors=None):
+        (feats, voxel_feats, point2voxel_map,
+         voxel_points_count) = ctx.saved_tensors
+        grad_feats = torch.zeros_like(feats)
+        # TODO: whether to use index put or use cuda_backward
+        # To use index put, need point to voxel index
+        ext_module.dynamic_point_to_voxel_backward(
+            grad_feats, grad_voxel_feats.contiguous(), feats, voxel_feats,
+            point2voxel_map, voxel_points_count, ctx.reduce_type)
+        return grad_feats, None, None
+
+
+dynamic_scatter = _DynamicScatter.apply
+
+
+class DynamicScatter(nn.Module):
+    """Scatters points into voxels, used in the voxel encoder with dynamic
+    voxelization.
+
+    Note:
+        The CPU and GPU implementations produce the same output, but may
+        differ numerically after summation and division (e.g., 5e-7).
+
+    Args:
+        voxel_size (list): The [x, y, z] size of a voxel.
+        point_cloud_range (list): The coordinate range of points, [x_min,
+            y_min, z_min, x_max, y_max, z_max].
+        average_points (bool): whether to use average pooling when scattering
+            points into a voxel.
+    """
+
+    def __init__(self, voxel_size, point_cloud_range, average_points: bool):
+        super().__init__()
+
+        self.voxel_size = voxel_size
+        self.point_cloud_range = point_cloud_range
+        self.average_points = average_points
+
+    def forward_single(self, points, coors):
+        """Scatters points into voxels.
+
+        Args:
+            points (torch.Tensor): Points to be reduced into voxels.
+            coors (torch.Tensor): Corresponding voxel coordinates (specifically
+                multi-dim voxel index) of each point.
+
+        Returns:
+            voxel_feats (torch.Tensor): Reduced features; input features that
+                share the same voxel coordinates are reduced to one row.
+            voxel_coors (torch.Tensor): Voxel coordinates.
+        """
+        reduce = 'mean' if self.average_points else 'max'
+        return dynamic_scatter(points.contiguous(), coors.contiguous(), reduce)
+
+    def forward(self, points, coors):
+        """Scatters points/features into voxels.
+
+        Args:
+            points (torch.Tensor): Points to be reduced into voxels.
+            coors (torch.Tensor): Corresponding voxel coordinates (specifically
+                multi-dim voxel index) of each point.
+
+        Returns:
+            voxel_feats (torch.Tensor): Reduced features; input features that
+                share the same voxel coordinates are reduced to one row.
+            voxel_coors (torch.Tensor): Voxel coordinates.
+        """
+        if coors.size(-1) == 3:
+            return self.forward_single(points, coors)
+        else:
+            batch_size = coors[-1, 0] + 1
+            voxels, voxel_coors = [], []
+            for i in range(batch_size):
+                inds = torch.where(coors[:, 0] == i)
+                voxel, voxel_coor = self.forward_single(
+                    points[inds], coors[inds][:, 1:])
+                coor_pad = nn.functional.pad(
+                    voxel_coor, (1, 0), mode='constant', value=i)
+                voxel_coors.append(coor_pad)
+                voxels.append(voxel)
+            features = torch.cat(voxels, dim=0)
+            feature_coors = torch.cat(voxel_coors, dim=0)
+
+            return features, feature_coors
+
+    def __repr__(self):
+        s = self.__class__.__name__ + '('
+        s += 'voxel_size=' + str(self.voxel_size)
+        s += ', point_cloud_range=' + str(self.point_cloud_range)
+        s += ', average_points=' + str(self.average_points)
+        s += ')'
+        return s
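+
+
+# Minimal usage sketch (illustrative only): it assumes the compiled `_ext`
+# CUDA extension and a CUDA device are available; the voxel grid and point
+# features below are arbitrary.
+if __name__ == '__main__':
+    if torch.cuda.is_available():
+        scatter_op = DynamicScatter(
+            voxel_size=[0.1, 0.1, 0.1],
+            point_cloud_range=[-1, -1, -1, 1, 1, 1],
+            average_points=True)
+        feats = torch.rand(100, 4, device='cuda')
+        # per-point voxel coordinates (z, y, x)
+        coors = torch.randint(0, 20, (100, 3), device='cuda').int()
+        voxel_feats, voxel_coors = scatter_op(feats, coors)
+        print(voxel_feats.shape, voxel_coors.shape)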
diff --git a/annotator/uniformer/mmcv/ops/sync_bn.py b/annotator/uniformer/mmcv/ops/sync_bn.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9b016fcbe860989c56cd1040034bcfa60e146d2
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/sync_bn.py
@@ -0,0 +1,279 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.module import Module
+from torch.nn.parameter import Parameter
+
+from annotator.uniformer.mmcv.cnn import NORM_LAYERS
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', [
+    'sync_bn_forward_mean', 'sync_bn_forward_var', 'sync_bn_forward_output',
+    'sync_bn_backward_param', 'sync_bn_backward_data'
+])
+
+
+class SyncBatchNormFunction(Function):
+
+    @staticmethod
+    def symbolic(g, input, running_mean, running_var, weight, bias, momentum,
+                 eps, group, group_size, stats_mode):
+        return g.op(
+            'mmcv::MMCVSyncBatchNorm',
+            input,
+            running_mean,
+            running_var,
+            weight,
+            bias,
+            momentum_f=momentum,
+            eps_f=eps,
+            group_i=group,
+            group_size_i=group_size,
+            stats_mode=stats_mode)
+
+    @staticmethod
+    def forward(self, input, running_mean, running_var, weight, bias, momentum,
+                eps, group, group_size, stats_mode):
+        self.momentum = momentum
+        self.eps = eps
+        self.group = group
+        self.group_size = group_size
+        self.stats_mode = stats_mode
+
+        assert isinstance(
+                   input, (torch.HalfTensor, torch.FloatTensor,
+                           torch.cuda.HalfTensor, torch.cuda.FloatTensor)), \
+               f'only support Half or Float Tensor, but {input.type()}'
+        output = torch.zeros_like(input)
+        input3d = input.flatten(start_dim=2)
+        output3d = output.view_as(input3d)
+        num_channels = input3d.size(1)
+
+        # ensure mean/var/norm/std are initialized as zeros
+        # ``torch.empty()`` does not guarantee that
+        mean = torch.zeros(
+            num_channels, dtype=torch.float, device=input3d.device)
+        var = torch.zeros(
+            num_channels, dtype=torch.float, device=input3d.device)
+        norm = torch.zeros_like(
+            input3d, dtype=torch.float, device=input3d.device)
+        std = torch.zeros(
+            num_channels, dtype=torch.float, device=input3d.device)
+
+        batch_size = input3d.size(0)
+        if batch_size > 0:
+            ext_module.sync_bn_forward_mean(input3d, mean)
+            batch_flag = torch.ones([1], device=mean.device, dtype=mean.dtype)
+        else:
+            # skip updating mean and leave it as zeros when the input is empty
+            batch_flag = torch.zeros([1], device=mean.device, dtype=mean.dtype)
+
+        # synchronize mean and the batch flag
+        vec = torch.cat([mean, batch_flag])
+        if self.stats_mode == 'N':
+            vec *= batch_size
+        if self.group_size > 1:
+            dist.all_reduce(vec, group=self.group)
+        total_batch = vec[-1].detach()
+        mean = vec[:num_channels]
+
+        if self.stats_mode == 'default':
+            mean = mean / self.group_size
+        elif self.stats_mode == 'N':
+            mean = mean / total_batch.clamp(min=1)
+        else:
+            raise NotImplementedError
+
+        # leave var as zeros when the input is empty
+        if batch_size > 0:
+            ext_module.sync_bn_forward_var(input3d, mean, var)
+
+        if self.stats_mode == 'N':
+            var *= batch_size
+        if self.group_size > 1:
+            dist.all_reduce(var, group=self.group)
+
+        if self.stats_mode == 'default':
+            var /= self.group_size
+        elif self.stats_mode == 'N':
+            var /= total_batch.clamp(min=1)
+        else:
+            raise NotImplementedError
+
+        # if the total batch size over all the ranks is zero,
+        # we should not update the statistics in the current batch
+        update_flag = total_batch.clamp(max=1)
+        momentum = update_flag * self.momentum
+        ext_module.sync_bn_forward_output(
+            input3d,
+            mean,
+            var,
+            weight,
+            bias,
+            running_mean,
+            running_var,
+            norm,
+            std,
+            output3d,
+            eps=self.eps,
+            momentum=momentum,
+            group_size=self.group_size)
+        self.save_for_backward(norm, std, weight)
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(self, grad_output):
+        norm, std, weight = self.saved_tensors
+        grad_weight = torch.zeros_like(weight)
+        grad_bias = torch.zeros_like(weight)
+        grad_input = torch.zeros_like(grad_output)
+        grad_output3d = grad_output.flatten(start_dim=2)
+        grad_input3d = grad_input.view_as(grad_output3d)
+
+        batch_size = grad_input3d.size(0)
+        if batch_size > 0:
+            ext_module.sync_bn_backward_param(grad_output3d, norm, grad_weight,
+                                              grad_bias)
+
+        # all reduce
+        if self.group_size > 1:
+            dist.all_reduce(grad_weight, group=self.group)
+            dist.all_reduce(grad_bias, group=self.group)
+            grad_weight /= self.group_size
+            grad_bias /= self.group_size
+
+        if batch_size > 0:
+            ext_module.sync_bn_backward_data(grad_output3d, weight,
+                                             grad_weight, grad_bias, norm, std,
+                                             grad_input3d)
+
+        return grad_input, None, None, grad_weight, grad_bias, \
+            None, None, None, None, None
+
+
+@NORM_LAYERS.register_module(name='MMSyncBN')
+class SyncBatchNorm(Module):
+    """Synchronized Batch Normalization.
+
+    Args:
+        num_features (int): number of features/channels in the input tensor
+        eps (float, optional): a value added to the denominator for numerical
+            stability. Defaults to 1e-5.
+        momentum (float, optional): the value used for the running_mean and
+            running_var computation. Defaults to 0.1.
+        affine (bool, optional): whether to use learnable affine parameters.
+            Defaults to True.
+        track_running_stats (bool, optional): whether to track the running
+            mean and variance during training. When set to False, this
+            module does not track such statistics, and initializes statistics
+            buffers ``running_mean`` and ``running_var`` as ``None``. When
+            these buffers are ``None``, this module always uses batch
+            statistics in both training and eval modes. Defaults to True.
+        group (int, optional): synchronization of stats happens within each
+            process group individually. By default it is synchronized across
+            the whole world. Defaults to None.
+        stats_mode (str, optional): The statistical mode. Available options
+            include ``'default'`` and ``'N'``. Defaults to 'default'.
+            When ``stats_mode=='default'``, it computes the overall statistics
+            using those from each worker with equal weight, i.e., the
+            statistics are synchronized and simply divided by ``group``. This
+            mode will produce inaccurate statistics when empty tensors occur.
+            When ``stats_mode=='N'``, it computes the overall statistics using
+            the total number of batches in each worker, ignoring the number of
+            workers, i.e., the statistics are synchronized and then divided by
+            the total batch size ``N``. This mode is beneficial when empty
+            tensors occur during training, as it averages the total mean by
+            the real number of batches.
+    """
+
+    def __init__(self,
+                 num_features,
+                 eps=1e-5,
+                 momentum=0.1,
+                 affine=True,
+                 track_running_stats=True,
+                 group=None,
+                 stats_mode='default'):
+        super(SyncBatchNorm, self).__init__()
+        self.num_features = num_features
+        self.eps = eps
+        self.momentum = momentum
+        self.affine = affine
+        self.track_running_stats = track_running_stats
+        group = dist.group.WORLD if group is None else group
+        self.group = group
+        self.group_size = dist.get_world_size(group)
+        assert stats_mode in ['default', 'N'], \
+            f'"stats_mode" only accepts "default" and "N", got "{stats_mode}"'
+        self.stats_mode = stats_mode
+        if self.affine:
+            self.weight = Parameter(torch.Tensor(num_features))
+            self.bias = Parameter(torch.Tensor(num_features))
+        else:
+            self.register_parameter('weight', None)
+            self.register_parameter('bias', None)
+        if self.track_running_stats:
+            self.register_buffer('running_mean', torch.zeros(num_features))
+            self.register_buffer('running_var', torch.ones(num_features))
+            self.register_buffer('num_batches_tracked',
+                                 torch.tensor(0, dtype=torch.long))
+        else:
+            self.register_buffer('running_mean', None)
+            self.register_buffer('running_var', None)
+            self.register_buffer('num_batches_tracked', None)
+        self.reset_parameters()
+
+    def reset_running_stats(self):
+        if self.track_running_stats:
+            self.running_mean.zero_()
+            self.running_var.fill_(1)
+            self.num_batches_tracked.zero_()
+
+    def reset_parameters(self):
+        self.reset_running_stats()
+        if self.affine:
+            self.weight.data.uniform_()  # pytorch uses ones_()
+            self.bias.data.zero_()
+
+    def forward(self, input):
+        if input.dim() < 2:
+            raise ValueError(
+                f'expected at least 2D input, got {input.dim()}D input')
+        if self.momentum is None:
+            exponential_average_factor = 0.0
+        else:
+            exponential_average_factor = self.momentum
+
+        if self.training and self.track_running_stats:
+            if self.num_batches_tracked is not None:
+                self.num_batches_tracked += 1
+                if self.momentum is None:  # use cumulative moving average
+                    exponential_average_factor = 1.0 / float(
+                        self.num_batches_tracked)
+                else:  # use exponential moving average
+                    exponential_average_factor = self.momentum
+
+        if self.training or not self.track_running_stats:
+            return SyncBatchNormFunction.apply(
+                input, self.running_mean, self.running_var, self.weight,
+                self.bias, exponential_average_factor, self.eps, self.group,
+                self.group_size, self.stats_mode)
+        else:
+            return F.batch_norm(input, self.running_mean, self.running_var,
+                                self.weight, self.bias, False,
+                                exponential_average_factor, self.eps)
+
+    def __repr__(self):
+        s = self.__class__.__name__
+        s += f'({self.num_features}, '
+        s += f'eps={self.eps}, '
+        s += f'momentum={self.momentum}, '
+        s += f'affine={self.affine}, '
+        s += f'track_running_stats={self.track_running_stats}, '
+        s += f'group_size={self.group_size}, '
+        s += f'stats_mode={self.stats_mode})'
+        return s
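+
+
+# Minimal usage sketch (illustrative only): a single-process gloo group so
+# the module can be constructed without multiple GPUs; eval mode falls back
+# to F.batch_norm, so the CUDA extension is not exercised here.
+if __name__ == '__main__':
+    import os
+    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
+    os.environ.setdefault('MASTER_PORT', '29500')
+    dist.init_process_group('gloo', rank=0, world_size=1)
+    bn = SyncBatchNorm(8)
+    bn.eval()
+    out = bn(torch.rand(2, 8, 4, 4))
+    print(bn, out.shape)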
diff --git a/annotator/uniformer/mmcv/ops/three_interpolate.py b/annotator/uniformer/mmcv/ops/three_interpolate.py
new file mode 100644
index 0000000000000000000000000000000000000000..203f47f05d58087e034fb3cd8cd6a09233947b4a
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/three_interpolate.py
@@ -0,0 +1,68 @@
+from typing import Tuple
+
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+    '_ext', ['three_interpolate_forward', 'three_interpolate_backward'])
+
+
+class ThreeInterpolate(Function):
+    """Performs weighted linear interpolation on 3 features.
+
+    Please refer to `Paper of PointNet++ <https://arxiv.org/abs/1706.02413>`_
+    for more details.
+    """
+
+    @staticmethod
+    def forward(ctx, features: torch.Tensor, indices: torch.Tensor,
+                weight: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            features (Tensor): (B, C, M) feature descriptors to be
+                interpolated.
+            indices (Tensor): (B, n, 3) indices of the three nearest
+                neighbors of the target features in ``features``.
+            weight (Tensor): (B, n, 3) weights of the interpolation.
+
+        Returns:
+            Tensor: (B, C, N) tensor of the interpolated features
+        """
+        assert features.is_contiguous()
+        assert indices.is_contiguous()
+        assert weight.is_contiguous()
+
+        B, c, m = features.size()
+        n = indices.size(1)
+        ctx.three_interpolate_for_backward = (indices, weight, m)
+        output = torch.cuda.FloatTensor(B, c, n)
+
+        ext_module.three_interpolate_forward(
+            features, indices, weight, output, b=B, c=c, m=m, n=n)
+        return output
+
+    @staticmethod
+    def backward(
+        ctx, grad_out: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            grad_out (Tensor): (B, C, N) tensor with gradients of outputs
+
+        Returns:
+            Tensor: (B, C, M) tensor with gradients of features
+        """
+        idx, weight, m = ctx.three_interpolate_for_backward
+        B, c, n = grad_out.size()
+
+        grad_features = torch.cuda.FloatTensor(B, c, m).zero_()
+        grad_out_data = grad_out.data.contiguous()
+
+        ext_module.three_interpolate_backward(
+            grad_out_data, idx, weight, grad_features.data, b=B, c=c, n=n, m=m)
+        return grad_features, None, None
+
+
+three_interpolate = ThreeInterpolate.apply
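+
+
+# Minimal usage sketch (illustrative only): it assumes the compiled `_ext`
+# CUDA extension and a CUDA device are available; shapes are arbitrary.
+if __name__ == '__main__':
+    if torch.cuda.is_available():
+        features = torch.rand(1, 16, 32, device='cuda')  # (B, C, M)
+        indices = torch.randint(0, 32, (1, 64, 3), device='cuda').int()
+        weight = torch.full((1, 64, 3), 1.0 / 3, device='cuda')
+        out = three_interpolate(features, indices, weight)  # -> (1, 16, 64)
+        print(out.shape)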
diff --git a/annotator/uniformer/mmcv/ops/three_nn.py b/annotator/uniformer/mmcv/ops/three_nn.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b01047a129989cd5545a0a86f23a487f4a13ce1
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/three_nn.py
@@ -0,0 +1,51 @@
+from typing import Tuple
+
+import torch
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext', ['three_nn_forward'])
+
+
+class ThreeNN(Function):
+    """Find the top-3 nearest neighbors of the target set from the source set.
+
+    Please refer to `Paper of PointNet++ <https://arxiv.org/abs/1706.02413>`_
+    for more details.
+    """
+
+    @staticmethod
+    def forward(ctx, target: torch.Tensor,
+                source: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            target (Tensor): shape (B, N, 3), the set of points for which to
+                find the nearest neighbors.
+            source (Tensor): shape (B, M, 3), the set of points in which the
+                nearest neighbors of the target points are searched.
+
+        Returns:
+            dist (Tensor): shape (B, N, 3), L2 distance of each target point
+                to its three nearest neighbors.
+            idx (Tensor): shape (B, N, 3), indices of the three nearest
+                neighbors in the source set.
+        """
+        target = target.contiguous()
+        source = source.contiguous()
+
+        B, N, _ = target.size()
+        m = source.size(1)
+        dist2 = torch.cuda.FloatTensor(B, N, 3)
+        idx = torch.cuda.IntTensor(B, N, 3)
+
+        ext_module.three_nn_forward(target, source, dist2, idx, b=B, n=N, m=m)
+        if torch.__version__ != 'parrots':
+            ctx.mark_non_differentiable(idx)
+
+        return torch.sqrt(dist2), idx
+
+    @staticmethod
+    def backward(ctx, a=None, b=None):
+        return None, None
+
+
+three_nn = ThreeNN.apply
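+
+
+# Minimal usage sketch (illustrative only): it assumes the compiled `_ext`
+# CUDA extension and a CUDA device are available; shapes are arbitrary.
+if __name__ == '__main__':
+    if torch.cuda.is_available():
+        target = torch.rand(1, 128, 3, device='cuda')
+        source = torch.rand(1, 64, 3, device='cuda')
+        dist, idx = three_nn(target, source)
+        print(dist.shape, idx.shape)  # (1, 128, 3), (1, 128, 3)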
diff --git a/annotator/uniformer/mmcv/ops/tin_shift.py b/annotator/uniformer/mmcv/ops/tin_shift.py
new file mode 100644
index 0000000000000000000000000000000000000000..472c9fcfe45a124e819b7ed5653e585f94a8811e
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/tin_shift.py
@@ -0,0 +1,68 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Code reference from "Temporal Interlacing Network"
+# https://github.com/deepcs233/TIN/blob/master/cuda_shift/rtc_wrap.py
+# Hao Shao, Shengju Qian, Yu Liu
+# shaoh19@mails.tsinghua.edu.cn, sjqian@cse.cuhk.edu.hk, yuliu@ee.cuhk.edu.hk
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext('_ext',
+                                 ['tin_shift_forward', 'tin_shift_backward'])
+
+
+class TINShiftFunction(Function):
+
+    @staticmethod
+    def forward(ctx, input, shift):
+        C = input.size(2)
+        num_segments = shift.size(1)
+        if C // num_segments <= 0 or C % num_segments != 0:
+            raise ValueError('C should be a multiple of num_segments, '
+                             f'but got C={C} and num_segments={num_segments}.')
+
+        ctx.save_for_backward(shift)
+
+        out = torch.zeros_like(input)
+        ext_module.tin_shift_forward(input, shift, out)
+
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+
+        shift = ctx.saved_tensors[0]
+        data_grad_input = grad_output.new(*grad_output.size()).zero_()
+        shift_grad_input = shift.new(*shift.size()).zero_()
+        ext_module.tin_shift_backward(grad_output, shift, data_grad_input)
+
+        return data_grad_input, shift_grad_input
+
+
+tin_shift = TINShiftFunction.apply
+
+
+class TINShift(nn.Module):
+    """Temporal Interlace Shift.
+
+    Temporal Interlace shift is a differentiable temporal-wise frame shifting
+    which is proposed in "Temporal Interlacing Network"
+
+    Please refer to https://arxiv.org/abs/2001.06499 for more details.
+    Code is modified from https://github.com/mit-han-lab/temporal-shift-module
+    """
+
+    def forward(self, input, shift):
+        """Perform temporal interlace shift.
+
+        Args:
+            input (Tensor): Feature map with shape [N, num_segments, C, H * W].
+            shift (Tensor): Shift tensor with shape [N, num_segments].
+
+        Returns:
+            Feature map after temporal interlace shift.
+        """
+        return tin_shift(input, shift)
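+
+
+# Minimal usage sketch (illustrative only): it assumes the compiled `_ext`
+# CUDA extension and a CUDA device are available; the integer shifts are
+# arbitrary and C must be divisible by num_segments.
+if __name__ == '__main__':
+    if torch.cuda.is_available():
+        num_segments = 4
+        x = torch.rand(2, num_segments, 8, 16, device='cuda')
+        shift = torch.randint(-1, 2, (2, num_segments), device='cuda').int()
+        out = TINShift()(x, shift)
+        print(out.shape)  # same shape as x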
diff --git a/annotator/uniformer/mmcv/ops/upfirdn2d.py b/annotator/uniformer/mmcv/ops/upfirdn2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8bb2c3c949eed38a6465ed369fa881538dca010
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/upfirdn2d.py
@@ -0,0 +1,330 @@
+# modified from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py  # noqa:E501
+
+# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
+# NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator
+# Augmentation (ADA)
+# =======================================================================
+
+# 1. Definitions
+
+# "Licensor" means any person or entity that distributes its Work.
+
+# "Software" means the original work of authorship made available under
+# this License.
+
+# "Work" means the Software and any additions to or derivative works of
+# the Software that are made available under this License.
+
+# The terms "reproduce," "reproduction," "derivative works," and
+# "distribution" have the meaning as provided under U.S. copyright law;
+# provided, however, that for the purposes of this License, derivative
+# works shall not include works that remain separable from, or merely
+# link (or bind by name) to the interfaces of, the Work.
+
+# Works, including the Software, are "made available" under this License
+# by including in or with the Work either (a) a copyright notice
+# referencing the applicability of this License to the Work, or (b) a
+# copy of this License.
+
+# 2. License Grants
+
+#     2.1 Copyright Grant. Subject to the terms and conditions of this
+#     License, each Licensor grants to you a perpetual, worldwide,
+#     non-exclusive, royalty-free, copyright license to reproduce,
+#     prepare derivative works of, publicly display, publicly perform,
+#     sublicense and distribute its Work and any resulting derivative
+#     works in any form.
+
+# 3. Limitations
+
+#     3.1 Redistribution. You may reproduce or distribute the Work only
+#     if (a) you do so under this License, (b) you include a complete
+#     copy of this License with your distribution, and (c) you retain
+#     without modification any copyright, patent, trademark, or
+#     attribution notices that are present in the Work.
+
+#     3.2 Derivative Works. You may specify that additional or different
+#     terms apply to the use, reproduction, and distribution of your
+#     derivative works of the Work ("Your Terms") only if (a) Your Terms
+#     provide that the use limitation in Section 3.3 applies to your
+#     derivative works, and (b) you identify the specific derivative
+#     works that are subject to Your Terms. Notwithstanding Your Terms,
+#     this License (including the redistribution requirements in Section
+#     3.1) will continue to apply to the Work itself.
+
+#     3.3 Use Limitation. The Work and any derivative works thereof only
+#     may be used or intended for use non-commercially. Notwithstanding
+#     the foregoing, NVIDIA and its affiliates may use the Work and any
+#     derivative works commercially. As used herein, "non-commercially"
+#     means for research or evaluation purposes only.
+
+#     3.4 Patent Claims. If you bring or threaten to bring a patent claim
+#     against any Licensor (including any claim, cross-claim or
+#     counterclaim in a lawsuit) to enforce any patents that you allege
+#     are infringed by any Work, then your rights under this License from
+#     such Licensor (including the grant in Section 2.1) will terminate
+#     immediately.
+
+#     3.5 Trademarks. This License does not grant any rights to use any
+#     Licensor’s or its affiliates’ names, logos, or trademarks, except
+#     as necessary to reproduce the notices described in this License.
+
+#     3.6 Termination. If you violate any term of this License, then your
+#     rights under this License (including the grant in Section 2.1) will
+#     terminate immediately.
+
+# 4. Disclaimer of Warranty.
+
+# THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
+# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
+# NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
+# THIS LICENSE.
+
+# 5. Limitation of Liability.
+
+# EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
+# THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
+# SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
+# INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
+# OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
+# (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
+# LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
+# COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGES.
+
+# =======================================================================
+
+import torch
+from torch.autograd import Function
+from torch.nn import functional as F
+
+from annotator.uniformer.mmcv.utils import to_2tuple
+from ..utils import ext_loader
+
+upfirdn2d_ext = ext_loader.load_ext('_ext', ['upfirdn2d'])
+
+
+class UpFirDn2dBackward(Function):
+
+    @staticmethod
+    def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad,
+                in_size, out_size):
+
+        up_x, up_y = up
+        down_x, down_y = down
+        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
+
+        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
+
+        grad_input = upfirdn2d_ext.upfirdn2d(
+            grad_output,
+            grad_kernel,
+            up_x=down_x,
+            up_y=down_y,
+            down_x=up_x,
+            down_y=up_y,
+            pad_x0=g_pad_x0,
+            pad_x1=g_pad_x1,
+            pad_y0=g_pad_y0,
+            pad_y1=g_pad_y1)
+        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2],
+                                     in_size[3])
+
+        ctx.save_for_backward(kernel)
+
+        pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+        ctx.up_x = up_x
+        ctx.up_y = up_y
+        ctx.down_x = down_x
+        ctx.down_y = down_y
+        ctx.pad_x0 = pad_x0
+        ctx.pad_x1 = pad_x1
+        ctx.pad_y0 = pad_y0
+        ctx.pad_y1 = pad_y1
+        ctx.in_size = in_size
+        ctx.out_size = out_size
+
+        return grad_input
+
+    @staticmethod
+    def backward(ctx, gradgrad_input):
+        kernel, = ctx.saved_tensors
+
+        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2],
+                                                ctx.in_size[3], 1)
+
+        gradgrad_out = upfirdn2d_ext.upfirdn2d(
+            gradgrad_input,
+            kernel,
+            up_x=ctx.up_x,
+            up_y=ctx.up_y,
+            down_x=ctx.down_x,
+            down_y=ctx.down_y,
+            pad_x0=ctx.pad_x0,
+            pad_x1=ctx.pad_x1,
+            pad_y0=ctx.pad_y0,
+            pad_y1=ctx.pad_y1)
+        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
+        #                                  ctx.out_size[1], ctx.in_size[3])
+        gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1],
+                                         ctx.out_size[0], ctx.out_size[1])
+
+        return gradgrad_out, None, None, None, None, None, None, None, None
+
+
+class UpFirDn2d(Function):
+
+    @staticmethod
+    def forward(ctx, input, kernel, up, down, pad):
+        up_x, up_y = up
+        down_x, down_y = down
+        pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+        kernel_h, kernel_w = kernel.shape
+        batch, channel, in_h, in_w = input.shape
+        ctx.in_size = input.shape
+
+        input = input.reshape(-1, in_h, in_w, 1)
+
+        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
+
+        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+        ctx.out_size = (out_h, out_w)
+
+        ctx.up = (up_x, up_y)
+        ctx.down = (down_x, down_y)
+        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
+
+        g_pad_x0 = kernel_w - pad_x0 - 1
+        g_pad_y0 = kernel_h - pad_y0 - 1
+        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
+        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
+
+        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
+
+        out = upfirdn2d_ext.upfirdn2d(
+            input,
+            kernel,
+            up_x=up_x,
+            up_y=up_y,
+            down_x=down_x,
+            down_y=down_y,
+            pad_x0=pad_x0,
+            pad_x1=pad_x1,
+            pad_y0=pad_y0,
+            pad_y1=pad_y1)
+        # out = out.view(major, out_h, out_w, minor)
+        out = out.view(-1, channel, out_h, out_w)
+
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        kernel, grad_kernel = ctx.saved_tensors
+
+        grad_input = UpFirDn2dBackward.apply(
+            grad_output,
+            kernel,
+            grad_kernel,
+            ctx.up,
+            ctx.down,
+            ctx.pad,
+            ctx.g_pad,
+            ctx.in_size,
+            ctx.out_size,
+        )
+
+        return grad_input, None, None, None, None
+
+
+def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+    """UpFRIDn for 2d features.
+
+    UpFIRDn is short for upsample, apply FIR filter and downsample. More
+    details can be found in:
+    https://www.mathworks.com/help/signal/ref/upfirdn.html
+
+    Args:
+        input (Tensor): Tensor with shape of (n, c, h, w).
+        kernel (Tensor): Filter kernel.
+        up (int | tuple[int], optional): Upsampling factor. If given a number,
+            this factor is used for both the height and width sides.
+            Defaults to 1.
+        down (int | tuple[int], optional): Downsampling factor. If given a
+            number, this factor is used for both the height and width sides.
+            Defaults to 1.
+        pad (tuple[int], optional): Padding for tensors, (x_pad, y_pad) or
+            (x_pad_0, x_pad_1, y_pad_0, y_pad_1). Defaults to (0, 0).
+
+    Returns:
+        Tensor: Tensor after UpFIRDn.
+    """
+    if input.device.type == 'cpu':
+        if len(pad) == 2:
+            pad = (pad[0], pad[1], pad[0], pad[1])
+
+        up = to_2tuple(up)
+
+        down = to_2tuple(down)
+
+        out = upfirdn2d_native(input, kernel, up[0], up[1], down[0], down[1],
+                               pad[0], pad[1], pad[2], pad[3])
+    else:
+        _up = to_2tuple(up)
+
+        _down = to_2tuple(down)
+
+        if len(pad) == 4:
+            _pad = pad
+        elif len(pad) == 2:
+            _pad = (pad[0], pad[1], pad[0], pad[1])
+
+        out = UpFirDn2d.apply(input, kernel, _up, _down, _pad)
+
+    return out
+
+
+def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
+                     pad_y0, pad_y1):
+    _, channel, in_h, in_w = input.shape
+    input = input.reshape(-1, in_h, in_w, 1)
+
+    _, in_h, in_w, minor = input.shape
+    kernel_h, kernel_w = kernel.shape
+
+    out = input.view(-1, in_h, 1, in_w, 1, minor)
+    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+    out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+    out = F.pad(
+        out,
+        [0, 0,
+         max(pad_x0, 0),
+         max(pad_x1, 0),
+         max(pad_y0, 0),
+         max(pad_y1, 0)])
+    out = out[:,
+              max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0),
+              max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ]
+
+    out = out.permute(0, 3, 1, 2)
+    out = out.reshape(
+        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
+    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+    out = F.conv2d(out, w)
+    out = out.reshape(
+        -1,
+        minor,
+        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+    )
+    out = out.permute(0, 2, 3, 1)
+    out = out[:, ::down_y, ::down_x, :]
+
+    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+
+    return out.view(-1, channel, out_h, out_w)
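+
+
+# Minimal usage sketch (illustrative only): a CPU input takes the pure
+# PyTorch path through ``upfirdn2d_native``, so no CUDA extension is
+# needed; the box-blur kernel is arbitrary.
+if __name__ == '__main__':
+    x = torch.rand(1, 3, 8, 8)
+    blur_kernel = torch.ones(3, 3) / 9.0
+    out = upfirdn2d(x, blur_kernel, up=2, down=1, pad=(1, 1))
+    print(out.shape)  # -> (1, 3, 16, 16)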
diff --git a/annotator/uniformer/mmcv/ops/voxelize.py b/annotator/uniformer/mmcv/ops/voxelize.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca3226a4fbcbfe58490fa2ea8e1c16b531214121
--- /dev/null
+++ b/annotator/uniformer/mmcv/ops/voxelize.py
@@ -0,0 +1,132 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch import nn
+from torch.autograd import Function
+from torch.nn.modules.utils import _pair
+
+from ..utils import ext_loader
+
+ext_module = ext_loader.load_ext(
+    '_ext', ['dynamic_voxelize_forward', 'hard_voxelize_forward'])
+
+
+class _Voxelization(Function):
+
+    @staticmethod
+    def forward(ctx,
+                points,
+                voxel_size,
+                coors_range,
+                max_points=35,
+                max_voxels=20000):
+        """Convert kitti points(N, >=3) to voxels.
+
+        Args:
+            points (torch.Tensor): [N, ndim]. Points[:, :3] contain xyz points
+                and points[:, 3:] contain other information like reflectivity.
+            voxel_size (tuple or float): The size of voxel with the shape of
+                [3].
+            coors_range (tuple or float): The coordinate range of voxel with
+                the shape of [6].
+            max_points (int, optional): maximum points contained in a voxel.
+                If max_points=-1, dynamic voxelization is used. Default: 35.
+            max_voxels (int, optional): maximum number of voxels this function
+                creates. For SECOND, 20000 is a good choice. Users should
+                shuffle points before calling this function because max_voxels
+                may drop points. Default: 20000.
+
+        Returns:
+            voxels_out (torch.Tensor): Output voxels with the shape of [M,
+                max_points, ndim]. Only contain points and returned when
+                max_points != -1.
+            coors_out (torch.Tensor): Output coordinates with the shape of
+                [M, 3].
+            num_points_per_voxel_out (torch.Tensor): Num points per voxel with
+                the shape of [M]. Only returned when max_points != -1.
+        """
+        if max_points == -1 or max_voxels == -1:
+            coors = points.new_zeros(size=(points.size(0), 3), dtype=torch.int)
+            ext_module.dynamic_voxelize_forward(points, coors, voxel_size,
+                                                coors_range, 3)
+            return coors
+        else:
+            voxels = points.new_zeros(
+                size=(max_voxels, max_points, points.size(1)))
+            coors = points.new_zeros(size=(max_voxels, 3), dtype=torch.int)
+            num_points_per_voxel = points.new_zeros(
+                size=(max_voxels, ), dtype=torch.int)
+            voxel_num = ext_module.hard_voxelize_forward(
+                points, voxels, coors, num_points_per_voxel, voxel_size,
+                coors_range, max_points, max_voxels, 3)
+            # select the valid voxels
+            voxels_out = voxels[:voxel_num]
+            coors_out = coors[:voxel_num]
+            num_points_per_voxel_out = num_points_per_voxel[:voxel_num]
+            return voxels_out, coors_out, num_points_per_voxel_out
+
+
+voxelization = _Voxelization.apply
+
+
+class Voxelization(nn.Module):
+    """Convert kitti points(N, >=3) to voxels.
+
+    Please refer to `PVCNN <https://arxiv.org/abs/1907.03739>`_ for more
+    details.
+
+    Args:
+        voxel_size (tuple or float): The size of voxel with the shape of [3].
+        point_cloud_range (tuple or float): The coordinate range of voxel with
+            the shape of [6].
+        max_num_points (int): maximum points contained in a voxel. If
+            max_num_points=-1, dynamic voxelization is used.
+        max_voxels (int, optional): maximum number of voxels this function
+            creates. For SECOND, 20000 is a good choice. Users should shuffle
+            points before calling this function because max_voxels may drop
+            points. Default: 20000.
+    """
+
+    def __init__(self,
+                 voxel_size,
+                 point_cloud_range,
+                 max_num_points,
+                 max_voxels=20000):
+        super().__init__()
+
+        self.voxel_size = voxel_size
+        self.point_cloud_range = point_cloud_range
+        self.max_num_points = max_num_points
+        if isinstance(max_voxels, tuple):
+            self.max_voxels = max_voxels
+        else:
+            self.max_voxels = _pair(max_voxels)
+
+        point_cloud_range = torch.tensor(
+            point_cloud_range, dtype=torch.float32)
+        voxel_size = torch.tensor(voxel_size, dtype=torch.float32)
+        grid_size = (point_cloud_range[3:] -
+                     point_cloud_range[:3]) / voxel_size
+        grid_size = torch.round(grid_size).long()
+        input_feat_shape = grid_size[:2]
+        self.grid_size = grid_size
+        # the original shape is [x-len, y-len, z-len],
+        # i.e. [w, h, d]; reorder to [d, h, w]
+        self.pcd_shape = [*input_feat_shape, 1][::-1]
+
+    def forward(self, input):
+        if self.training:
+            max_voxels = self.max_voxels[0]
+        else:
+            max_voxels = self.max_voxels[1]
+
+        return voxelization(input, self.voxel_size, self.point_cloud_range,
+                            self.max_num_points, max_voxels)
+
+    def __repr__(self):
+        s = self.__class__.__name__ + '('
+        s += 'voxel_size=' + str(self.voxel_size)
+        s += ', point_cloud_range=' + str(self.point_cloud_range)
+        s += ', max_num_points=' + str(self.max_num_points)
+        s += ', max_voxels=' + str(self.max_voxels)
+        s += ')'
+        return s
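As a reading aid, here is a minimal usage sketch of the `Voxelization` layer added above. It assumes the vendored `_ext` extension (with `hard_voxelize_forward`) is compiled and that the class is importable from `annotator.uniformer.mmcv.ops`; the voxel size and point-cloud range values are illustrative only.

```python
# Hypothetical usage of the hard-voxelization layer; the import path and the
# numeric values are assumptions for illustration, not taken from this diff.
import torch
from annotator.uniformer.mmcv.ops import Voxelization

voxel_layer = Voxelization(
    voxel_size=[0.05, 0.05, 0.1],                 # [dx, dy, dz]
    point_cloud_range=[0, -40, -3, 70.4, 40, 1],  # [x_min, y_min, z_min, x_max, y_max, z_max]
    max_num_points=35,                            # points kept per voxel
    max_voxels=(16000, 40000))                    # (training, testing) budgets

points = torch.rand(20000, 4, device='cuda')      # x, y, z, reflectivity
voxels, coors, num_points = voxel_layer(points)   # [M, 35, 4], [M, 3], [M]
```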
diff --git a/annotator/uniformer/mmcv/parallel/__init__.py b/annotator/uniformer/mmcv/parallel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ed2c17ad357742e423beeaf4d35db03fe9af469
--- /dev/null
+++ b/annotator/uniformer/mmcv/parallel/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .collate import collate
+from .data_container import DataContainer
+from .data_parallel import MMDataParallel
+from .distributed import MMDistributedDataParallel
+from .registry import MODULE_WRAPPERS
+from .scatter_gather import scatter, scatter_kwargs
+from .utils import is_module_wrapper
+
+__all__ = [
+    'collate', 'DataContainer', 'MMDataParallel', 'MMDistributedDataParallel',
+    'scatter', 'scatter_kwargs', 'is_module_wrapper', 'MODULE_WRAPPERS'
+]
diff --git a/annotator/uniformer/mmcv/parallel/_functions.py b/annotator/uniformer/mmcv/parallel/_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b5a8a44483ab991411d07122b22a1d027e4be8e
--- /dev/null
+++ b/annotator/uniformer/mmcv/parallel/_functions.py
@@ -0,0 +1,79 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.nn.parallel._functions import _get_stream
+
+
+def scatter(input, devices, streams=None):
+    """Scatters tensor across multiple GPUs."""
+    if streams is None:
+        streams = [None] * len(devices)
+
+    if isinstance(input, list):
+        chunk_size = (len(input) - 1) // len(devices) + 1
+        outputs = [
+            scatter(input[i], [devices[i // chunk_size]],
+                    [streams[i // chunk_size]]) for i in range(len(input))
+        ]
+        return outputs
+    elif isinstance(input, torch.Tensor):
+        output = input.contiguous()
+        # TODO: copy to a pinned buffer first (if copying from CPU)
+        stream = streams[0] if output.numel() > 0 else None
+        if devices != [-1]:
+            with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
+                output = output.cuda(devices[0], non_blocking=True)
+        else:
+            # unsqueeze the first dimension so the tensor's shape matches
+            # those scattered to GPU.
+            output = output.unsqueeze(0)
+        return output
+    else:
+        raise Exception(f'Unknown type {type(input)}.')
+
+
+def synchronize_stream(output, devices, streams):
+    if isinstance(output, list):
+        chunk_size = len(output) // len(devices)
+        for i in range(len(devices)):
+            for j in range(chunk_size):
+                synchronize_stream(output[i * chunk_size + j], [devices[i]],
+                                   [streams[i]])
+    elif isinstance(output, torch.Tensor):
+        if output.numel() != 0:
+            with torch.cuda.device(devices[0]):
+                main_stream = torch.cuda.current_stream()
+                main_stream.wait_stream(streams[0])
+                output.record_stream(main_stream)
+    else:
+        raise Exception(f'Unknown type {type(output)}.')
+
+
+def get_input_device(input):
+    if isinstance(input, list):
+        for item in input:
+            input_device = get_input_device(item)
+            if input_device != -1:
+                return input_device
+        return -1
+    elif isinstance(input, torch.Tensor):
+        return input.get_device() if input.is_cuda else -1
+    else:
+        raise Exception(f'Unknown type {type(input)}.')
+
+
+class Scatter:
+
+    @staticmethod
+    def forward(target_gpus, input):
+        input_device = get_input_device(input)
+        streams = None
+        if input_device == -1 and target_gpus != [-1]:
+            # Perform CPU to GPU copies in a background stream
+            streams = [_get_stream(device) for device in target_gpus]
+
+        outputs = scatter(input, target_gpus, streams)
+        # Synchronize with the copy stream
+        if streams is not None:
+            synchronize_stream(outputs, target_gpus, streams)
+
+        return tuple(outputs)
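A quick sketch of the CPU path of `Scatter.forward` defined above: with `target_gpus == [-1]` no CUDA stream is involved, each tensor is unsqueezed, and the final `tuple()` yields a single chunk, mirroring the single-GPU case.

```python
# CPU-only scatter sketch; no CUDA device or stream is required here.
import torch
from annotator.uniformer.mmcv.parallel._functions import Scatter

x = torch.randn(4, 3)
out = Scatter.forward([-1], x)           # device id -1 means "keep on CPU"
assert len(out) == 1 and out[0].shape == (4, 3)
```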
diff --git a/annotator/uniformer/mmcv/parallel/collate.py b/annotator/uniformer/mmcv/parallel/collate.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad749197df21b0d74297548be5f66a696adebf7f
--- /dev/null
+++ b/annotator/uniformer/mmcv/parallel/collate.py
@@ -0,0 +1,84 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections.abc import Mapping, Sequence
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data.dataloader import default_collate
+
+from .data_container import DataContainer
+
+
+def collate(batch, samples_per_gpu=1):
+    """Puts each data field into a tensor/DataContainer with outer dimension
+    batch size.
+
+    Extends ``default_collate`` to add support for
+    :class:`~mmcv.parallel.DataContainer`. There are 3 cases:
+
+    1. cpu_only = True, e.g., meta data
+    2. cpu_only = False, stack = True, e.g., image tensors
+    3. cpu_only = False, stack = False, e.g., gt bboxes
+    """
+
+    if not isinstance(batch, Sequence):
+        raise TypeError(f'{type(batch)} is not supported.')
+
+    if isinstance(batch[0], DataContainer):
+        stacked = []
+        if batch[0].cpu_only:
+            for i in range(0, len(batch), samples_per_gpu):
+                stacked.append(
+                    [sample.data for sample in batch[i:i + samples_per_gpu]])
+            return DataContainer(
+                stacked, batch[0].stack, batch[0].padding_value, cpu_only=True)
+        elif batch[0].stack:
+            for i in range(0, len(batch), samples_per_gpu):
+                assert isinstance(batch[i].data, torch.Tensor)
+
+                if batch[i].pad_dims is not None:
+                    ndim = batch[i].dim()
+                    assert ndim > batch[i].pad_dims
+                    max_shape = [0 for _ in range(batch[i].pad_dims)]
+                    for dim in range(1, batch[i].pad_dims + 1):
+                        max_shape[dim - 1] = batch[i].size(-dim)
+                    for sample in batch[i:i + samples_per_gpu]:
+                        for dim in range(0, ndim - batch[i].pad_dims):
+                            assert batch[i].size(dim) == sample.size(dim)
+                        for dim in range(1, batch[i].pad_dims + 1):
+                            max_shape[dim - 1] = max(max_shape[dim - 1],
+                                                     sample.size(-dim))
+                    padded_samples = []
+                    for sample in batch[i:i + samples_per_gpu]:
+                        pad = [0 for _ in range(batch[i].pad_dims * 2)]
+                        for dim in range(1, batch[i].pad_dims + 1):
+                            pad[2 * dim -
+                                1] = max_shape[dim - 1] - sample.size(-dim)
+                        padded_samples.append(
+                            F.pad(
+                                sample.data, pad, value=sample.padding_value))
+                    stacked.append(default_collate(padded_samples))
+                elif batch[i].pad_dims is None:
+                    stacked.append(
+                        default_collate([
+                            sample.data
+                            for sample in batch[i:i + samples_per_gpu]
+                        ]))
+                else:
+                    raise ValueError(
+                        'pad_dims should be either None or integers (1-3)')
+
+        else:
+            for i in range(0, len(batch), samples_per_gpu):
+                stacked.append(
+                    [sample.data for sample in batch[i:i + samples_per_gpu]])
+        return DataContainer(stacked, batch[0].stack, batch[0].padding_value)
+    elif isinstance(batch[0], Sequence):
+        transposed = zip(*batch)
+        return [collate(samples, samples_per_gpu) for samples in transposed]
+    elif isinstance(batch[0], Mapping):
+        return {
+            key: collate([d[key] for d in batch], samples_per_gpu)
+            for key in batch[0]
+        }
+    else:
+        return default_collate(batch)
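To make the three cases concrete, the following sketch (the field names `img` and `meta` are invented) shows how a batch of `DataContainer`s is grouped per GPU chunk:

```python
# Illustrative collate of a 4-sample batch into 2 chunks of 2 samples each.
import torch
from annotator.uniformer.mmcv.parallel import DataContainer, collate

batch = [
    dict(img=DataContainer(torch.rand(3, 4, 5), stack=True),
         meta=DataContainer(dict(idx=i), cpu_only=True))
    for i in range(4)
]
out = collate(batch, samples_per_gpu=2)
# out['img'].data  -> [Tensor(2, 3, 4, 5), Tensor(2, 3, 4, 5)]   (stacked per chunk)
# out['meta'].data -> [[{'idx': 0}, {'idx': 1}], [{'idx': 2}, {'idx': 3}]]  (kept on CPU)
```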
diff --git a/annotator/uniformer/mmcv/parallel/data_container.py b/annotator/uniformer/mmcv/parallel/data_container.py
new file mode 100644
index 0000000000000000000000000000000000000000..cedb0d32a51a1f575a622b38de2cee3ab4757821
--- /dev/null
+++ b/annotator/uniformer/mmcv/parallel/data_container.py
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+
+import torch
+
+
+def assert_tensor_type(func):
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        if not isinstance(args[0].data, torch.Tensor):
+            raise AttributeError(
+                f'{args[0].__class__.__name__} has no attribute '
+                f'{func.__name__} for type {args[0].datatype}')
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
+class DataContainer:
+    """A container for any type of objects.
+
+    Typically tensors will be stacked in the collate function and sliced along
+    some dimension in the scatter function. This behavior has some limitations.
+    1. All tensors have to be the same size.
+    2. Types are limited (numpy array or Tensor).
+
+    We design `DataContainer` and `MMDataParallel` to overcome these
+    limitations. The behavior can be either of the following.
+
+    - copy to GPU, pad all tensors to the same size and stack them
+    - copy to GPU without stacking
+    - leave the objects as-is and pass them to the model
+    - ``pad_dims`` specifies the number of trailing dimensions to pad
+    """
+
+    def __init__(self,
+                 data,
+                 stack=False,
+                 padding_value=0,
+                 cpu_only=False,
+                 pad_dims=2):
+        self._data = data
+        self._cpu_only = cpu_only
+        self._stack = stack
+        self._padding_value = padding_value
+        assert pad_dims in [None, 1, 2, 3]
+        self._pad_dims = pad_dims
+
+    def __repr__(self):
+        return f'{self.__class__.__name__}({repr(self.data)})'
+
+    def __len__(self):
+        return len(self._data)
+
+    @property
+    def data(self):
+        return self._data
+
+    @property
+    def datatype(self):
+        if isinstance(self.data, torch.Tensor):
+            return self.data.type()
+        else:
+            return type(self.data)
+
+    @property
+    def cpu_only(self):
+        return self._cpu_only
+
+    @property
+    def stack(self):
+        return self._stack
+
+    @property
+    def padding_value(self):
+        return self._padding_value
+
+    @property
+    def pad_dims(self):
+        return self._pad_dims
+
+    @assert_tensor_type
+    def size(self, *args, **kwargs):
+        return self.data.size(*args, **kwargs)
+
+    @assert_tensor_type
+    def dim(self):
+        return self.data.dim()
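A small sketch of the `DataContainer` semantics above: tensor payloads expose `size()`/`dim()`, while non-tensor payloads trip the `assert_tensor_type` guard.

```python
import torch
from annotator.uniformer.mmcv.parallel import DataContainer

dc = DataContainer(torch.zeros(2, 3), stack=True)
print(dc.size(), dc.dim(), dc.stack)   # torch.Size([2, 3]) 2 True

meta = DataContainer({'filename': 'a.png'}, cpu_only=True)
print(meta.datatype)                   # <class 'dict'>
# meta.size() would raise AttributeError: the payload is not a tensor
```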
diff --git a/annotator/uniformer/mmcv/parallel/data_parallel.py b/annotator/uniformer/mmcv/parallel/data_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..79b5f69b654cf647dc7ae9174223781ab5c607d2
--- /dev/null
+++ b/annotator/uniformer/mmcv/parallel/data_parallel.py
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from itertools import chain
+
+from torch.nn.parallel import DataParallel
+
+from .scatter_gather import scatter_kwargs
+
+
+class MMDataParallel(DataParallel):
+    """The DataParallel module that supports DataContainer.
+
+    MMDataParallel has two main differences from PyTorch DataParallel:
+
+    - It supports a custom type :class:`DataContainer` which allows more
+      flexible control of input data during both GPU and CPU inference.
+    - It implements two more APIs: ``train_step()`` and ``val_step()``.
+
+    Args:
+        module (:class:`nn.Module`): Module to be encapsulated.
+        device_ids (list[int]): Device IDs of modules to scatter to.
+            Defaults to None when GPU is not available.
+        output_device (str | int): Device ID for output. Defaults to None.
+        dim (int): Dimension used to scatter the data. Defaults to 0.
+    """
+
+    def __init__(self, *args, dim=0, **kwargs):
+        super(MMDataParallel, self).__init__(*args, dim=dim, **kwargs)
+        self.dim = dim
+
+    def forward(self, *inputs, **kwargs):
+        """Override the original forward function.
+
+        The main difference lies in the CPU inference where the data in
+        :class:`DataContainers` will still be gathered.
+        """
+        if not self.device_ids:
+            # The following line lets the module gather and convert data
+            # containers as is done in GPU inference
+            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
+            return self.module(*inputs[0], **kwargs[0])
+        else:
+            return super().forward(*inputs, **kwargs)
+
+    def scatter(self, inputs, kwargs, device_ids):
+        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
+
+    def train_step(self, *inputs, **kwargs):
+        if not self.device_ids:
+            # The following line lets the module gather and convert data
+            # containers as is done in GPU inference
+            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
+            return self.module.train_step(*inputs[0], **kwargs[0])
+
+        assert len(self.device_ids) == 1, \
+            ('MMDataParallel only supports single GPU training, if you need to'
+             ' train with multiple GPUs, please use MMDistributedDataParallel'
+             ' instead.')
+
+        for t in chain(self.module.parameters(), self.module.buffers()):
+            if t.device != self.src_device_obj:
+                raise RuntimeError(
+                    'module must have its parameters and buffers '
+                    f'on device {self.src_device_obj} (device_ids[0]) but '
+                    f'found one of them on device: {t.device}')
+
+        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+        return self.module.train_step(*inputs[0], **kwargs[0])
+
+    def val_step(self, *inputs, **kwargs):
+        if not self.device_ids:
+            # The following line lets the module gather and convert data
+            # containers as is done in GPU inference
+            inputs, kwargs = self.scatter(inputs, kwargs, [-1])
+            return self.module.val_step(*inputs[0], **kwargs[0])
+
+        assert len(self.device_ids) == 1, \
+            ('MMDataParallel only supports single GPU training, if you need to'
+             ' train with multiple GPUs, please use MMDistributedDataParallel'
+             ' instead.')
+
+        for t in chain(self.module.parameters(), self.module.buffers()):
+            if t.device != self.src_device_obj:
+                raise RuntimeError(
+                    'module must have its parameters and buffers '
+                    f'on device {self.src_device_obj} (device_ids[0]) but '
+                    f'found one of them on device: {t.device}')
+
+        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+        return self.module.val_step(*inputs[0], **kwargs[0])
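For orientation, a hedged single-GPU wrapping sketch (the toy module is an assumption; a real model would also implement `train_step()`/`val_step()`), requiring a CUDA device:

```python
import torch
import torch.nn as nn
from annotator.uniformer.mmcv.parallel import MMDataParallel

model = nn.Linear(8, 2)                          # stand-in model
wrapped = MMDataParallel(model.cuda(), device_ids=[0])
out = wrapped(torch.randn(4, 8).cuda())          # scattered to device_ids[0]
# On a machine without GPUs, device_ids ends up empty and the overridden
# forward() above gathers DataContainers on CPU before calling the module.
```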
diff --git a/annotator/uniformer/mmcv/parallel/distributed.py b/annotator/uniformer/mmcv/parallel/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e4c27903db58a54d37ea1ed9ec0104098b486f2
--- /dev/null
+++ b/annotator/uniformer/mmcv/parallel/distributed.py
@@ -0,0 +1,112 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.nn.parallel.distributed import (DistributedDataParallel,
+                                           _find_tensors)
+
+from annotator.uniformer.mmcv import print_log
+from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version
+from .scatter_gather import scatter_kwargs
+
+
+class MMDistributedDataParallel(DistributedDataParallel):
+    """The DDP module that supports DataContainer.
+
+    MMDDP has two main differences from PyTorch DDP:
+
+    - It supports a custom type :class:`DataContainer` which allows more
+      flexible control of input data.
+    - It implements two APIs: ``train_step()`` and ``val_step()``.
+    """
+
+    def to_kwargs(self, inputs, kwargs, device_id):
+        # Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8
+        # to move all tensors to device_id
+        return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim)
+
+    def scatter(self, inputs, kwargs, device_ids):
+        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
+
+    def train_step(self, *inputs, **kwargs):
+        """train_step() API for module wrapped by DistributedDataParallel.
+
+        This method is basically the same as
+        ``DistributedDataParallel.forward()``, while replacing
+        ``self.module.forward()`` with ``self.module.train_step()``.
+        It is compatible with PyTorch 1.1 - 1.5.
+        """
+
+        # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
+        # end of backward to the beginning of forward.
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.7')
+                and self.reducer._rebuild_buckets()):
+            print_log(
+                'Reducer buckets have been rebuilt in this iteration.',
+                logger='mmcv')
+
+        if getattr(self, 'require_forward_param_sync', True):
+            self._sync_params()
+        if self.device_ids:
+            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+            if len(self.device_ids) == 1:
+                output = self.module.train_step(*inputs[0], **kwargs[0])
+            else:
+                outputs = self.parallel_apply(
+                    self._module_copies[:len(inputs)], inputs, kwargs)
+                output = self.gather(outputs, self.output_device)
+        else:
+            output = self.module.train_step(*inputs, **kwargs)
+
+        if torch.is_grad_enabled() and getattr(
+                self, 'require_backward_grad_sync', True):
+            if self.find_unused_parameters:
+                self.reducer.prepare_for_backward(list(_find_tensors(output)))
+            else:
+                self.reducer.prepare_for_backward([])
+        else:
+            if ('parrots' not in TORCH_VERSION
+                    and digit_version(TORCH_VERSION) > digit_version('1.2')):
+                self.require_forward_param_sync = False
+        return output
+
+    def val_step(self, *inputs, **kwargs):
+        """val_step() API for module wrapped by DistributedDataParallel.
+
+        This method is basically the same as
+        ``DistributedDataParallel.forward()``, while replacing
+        ``self.module.forward()`` with ``self.module.val_step()``.
+        It is compatible with PyTorch 1.1 - 1.5.
+        """
+        # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
+        # end of backward to the beginning of forward.
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.7')
+                and self.reducer._rebuild_buckets()):
+            print_log(
+                'Reducer buckets have been rebuilt in this iteration.',
+                logger='mmcv')
+
+        if getattr(self, 'require_forward_param_sync', True):
+            self._sync_params()
+        if self.device_ids:
+            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+            if len(self.device_ids) == 1:
+                output = self.module.val_step(*inputs[0], **kwargs[0])
+            else:
+                outputs = self.parallel_apply(
+                    self._module_copies[:len(inputs)], inputs, kwargs)
+                output = self.gather(outputs, self.output_device)
+        else:
+            output = self.module.val_step(*inputs, **kwargs)
+
+        if torch.is_grad_enabled() and getattr(
+                self, 'require_backward_grad_sync', True):
+            if self.find_unused_parameters:
+                self.reducer.prepare_for_backward(list(_find_tensors(output)))
+            else:
+                self.reducer.prepare_for_backward([])
+        else:
+            if ('parrots' not in TORCH_VERSION
+                    and digit_version(TORCH_VERSION) > digit_version('1.2')):
+                self.require_forward_param_sync = False
+        return output
diff --git a/annotator/uniformer/mmcv/parallel/distributed_deprecated.py b/annotator/uniformer/mmcv/parallel/distributed_deprecated.py
new file mode 100644
index 0000000000000000000000000000000000000000..676937a2085d4da20fa87923041a200fca6214eb
--- /dev/null
+++ b/annotator/uniformer/mmcv/parallel/distributed_deprecated.py
@@ -0,0 +1,70 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch._utils import (_flatten_dense_tensors, _take_tensors,
+                          _unflatten_dense_tensors)
+
+from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version
+from .registry import MODULE_WRAPPERS
+from .scatter_gather import scatter_kwargs
+
+
+@MODULE_WRAPPERS.register_module()
+class MMDistributedDataParallel(nn.Module):
+
+    def __init__(self,
+                 module,
+                 dim=0,
+                 broadcast_buffers=True,
+                 bucket_cap_mb=25):
+        super(MMDistributedDataParallel, self).__init__()
+        self.module = module
+        self.dim = dim
+        self.broadcast_buffers = broadcast_buffers
+
+        self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024
+        self._sync_params()
+
+    def _dist_broadcast_coalesced(self, tensors, buffer_size):
+        for tensors in _take_tensors(tensors, buffer_size):
+            flat_tensors = _flatten_dense_tensors(tensors)
+            dist.broadcast(flat_tensors, 0)
+            for tensor, synced in zip(
+                    tensors, _unflatten_dense_tensors(flat_tensors, tensors)):
+                tensor.copy_(synced)
+
+    def _sync_params(self):
+        module_states = list(self.module.state_dict().values())
+        if len(module_states) > 0:
+            self._dist_broadcast_coalesced(module_states,
+                                           self.broadcast_bucket_size)
+        if self.broadcast_buffers:
+            if (TORCH_VERSION != 'parrots'
+                    and digit_version(TORCH_VERSION) < digit_version('1.0')):
+                buffers = [b.data for b in self.module._all_buffers()]
+            else:
+                buffers = [b.data for b in self.module.buffers()]
+            if len(buffers) > 0:
+                self._dist_broadcast_coalesced(buffers,
+                                               self.broadcast_bucket_size)
+
+    def scatter(self, inputs, kwargs, device_ids):
+        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
+
+    def forward(self, *inputs, **kwargs):
+        inputs, kwargs = self.scatter(inputs, kwargs,
+                                      [torch.cuda.current_device()])
+        return self.module(*inputs[0], **kwargs[0])
+
+    def train_step(self, *inputs, **kwargs):
+        inputs, kwargs = self.scatter(inputs, kwargs,
+                                      [torch.cuda.current_device()])
+        output = self.module.train_step(*inputs[0], **kwargs[0])
+        return output
+
+    def val_step(self, *inputs, **kwargs):
+        inputs, kwargs = self.scatter(inputs, kwargs,
+                                      [torch.cuda.current_device()])
+        output = self.module.val_step(*inputs[0], **kwargs[0])
+        return output
diff --git a/annotator/uniformer/mmcv/parallel/registry.py b/annotator/uniformer/mmcv/parallel/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..a204a07fba10e614223f090d1a57cf9c4d74d4a1
--- /dev/null
+++ b/annotator/uniformer/mmcv/parallel/registry.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+
+from annotator.uniformer.mmcv.utils import Registry
+
+MODULE_WRAPPERS = Registry('module wrapper')
+MODULE_WRAPPERS.register_module(module=DataParallel)
+MODULE_WRAPPERS.register_module(module=DistributedDataParallel)
diff --git a/annotator/uniformer/mmcv/parallel/scatter_gather.py b/annotator/uniformer/mmcv/parallel/scatter_gather.py
new file mode 100644
index 0000000000000000000000000000000000000000..900ff88566f8f14830590459dc4fd16d4b382e47
--- /dev/null
+++ b/annotator/uniformer/mmcv/parallel/scatter_gather.py
@@ -0,0 +1,59 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.nn.parallel._functions import Scatter as OrigScatter
+
+from ._functions import Scatter
+from .data_container import DataContainer
+
+
+def scatter(inputs, target_gpus, dim=0):
+    """Scatter inputs to target gpus.
+
+    The only difference from the original :func:`scatter` is the added
+    support for :class:`~mmcv.parallel.DataContainer`.
+    """
+
+    def scatter_map(obj):
+        if isinstance(obj, torch.Tensor):
+            if target_gpus != [-1]:
+                return OrigScatter.apply(target_gpus, None, dim, obj)
+            else:
+                # for CPU inference we use self-implemented scatter
+                return Scatter.forward(target_gpus, obj)
+        if isinstance(obj, DataContainer):
+            if obj.cpu_only:
+                return obj.data
+            else:
+                return Scatter.forward(target_gpus, obj.data)
+        if isinstance(obj, tuple) and len(obj) > 0:
+            return list(zip(*map(scatter_map, obj)))
+        if isinstance(obj, list) and len(obj) > 0:
+            out = list(map(list, zip(*map(scatter_map, obj))))
+            return out
+        if isinstance(obj, dict) and len(obj) > 0:
+            out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
+            return out
+        return [obj for targets in target_gpus]
+
+    # After scatter_map is called, a scatter_map cell will exist. This cell
+    # has a reference to the actual function scatter_map, which has references
+    # to a closure that has a reference to the scatter_map cell (because the
+    # fn is recursive). To avoid this reference cycle, we set the function to
+    # None, clearing the cell
+    try:
+        return scatter_map(inputs)
+    finally:
+        scatter_map = None
+
+
+def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
+    """Scatter with support for kwargs dictionary."""
+    inputs = scatter(inputs, target_gpus, dim) if inputs else []
+    kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
+    if len(inputs) < len(kwargs):
+        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
+    elif len(kwargs) < len(inputs):
+        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
+    inputs = tuple(inputs)
+    kwargs = tuple(kwargs)
+    return inputs, kwargs
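A tiny sketch of `scatter_kwargs` on the CPU path (`target_gpus=[-1]`), which is what `MMDataParallel` uses for CPU inference; it returns one chunk of positional and keyword arguments:

```python
import torch
from annotator.uniformer.mmcv.parallel import scatter_kwargs

inputs, kwargs = scatter_kwargs((torch.randn(2, 3),), dict(flag=True), [-1])
print(len(inputs), inputs[0][0].shape)   # 1 torch.Size([2, 3])
print(kwargs)                            # ({'flag': True},)
```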
diff --git a/annotator/uniformer/mmcv/parallel/utils.py b/annotator/uniformer/mmcv/parallel/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f5712cb42c38a2e8563bf563efb6681383cab9b
--- /dev/null
+++ b/annotator/uniformer/mmcv/parallel/utils.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .registry import MODULE_WRAPPERS
+
+
+def is_module_wrapper(module):
+    """Check if a module is a module wrapper.
+
+    The following 3 modules in MMCV (and their subclasses) are regarded as
+    module wrappers: DataParallel, DistributedDataParallel,
+    MMDistributedDataParallel (the deprecated version). You may add your own
+    module wrapper by registering it to mmcv.parallel.MODULE_WRAPPERS.
+
+    Args:
+        module (nn.Module): The module to be checked.
+
+    Returns:
+        bool: True if the input module is a module wrapper.
+    """
+    module_wrappers = tuple(MODULE_WRAPPERS.module_dict.values())
+    return isinstance(module, module_wrappers)
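As a sketch of the registry mechanism, a custom wrapper registered in `MODULE_WRAPPERS` is immediately recognised by `is_module_wrapper`; the `MyWrapper` class below is purely illustrative:

```python
import torch.nn as nn
from annotator.uniformer.mmcv.parallel import MODULE_WRAPPERS, is_module_wrapper

@MODULE_WRAPPERS.register_module()
class MyWrapper(nn.Module):
    """Toy wrapper that simply holds another module."""

    def __init__(self, module):
        super().__init__()
        self.module = module

print(is_module_wrapper(nn.Linear(2, 2)))             # False
print(is_module_wrapper(MyWrapper(nn.Linear(2, 2))))  # True
```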
diff --git a/annotator/uniformer/mmcv/runner/__init__.py b/annotator/uniformer/mmcv/runner/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..52e4b48d383a84a055dcd7f6236f6e8e58eab924
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/__init__.py
@@ -0,0 +1,47 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base_module import BaseModule, ModuleList, Sequential
+from .base_runner import BaseRunner
+from .builder import RUNNERS, build_runner
+from .checkpoint import (CheckpointLoader, _load_checkpoint,
+                         _load_checkpoint_with_prefix, load_checkpoint,
+                         load_state_dict, save_checkpoint, weights_to_cpu)
+from .default_constructor import DefaultRunnerConstructor
+from .dist_utils import (allreduce_grads, allreduce_params, get_dist_info,
+                         init_dist, master_only)
+from .epoch_based_runner import EpochBasedRunner, Runner
+from .fp16_utils import LossScaler, auto_fp16, force_fp32, wrap_fp16_model
+from .hooks import (HOOKS, CheckpointHook, ClosureHook, DistEvalHook,
+                    DistSamplerSeedHook, DvcliveLoggerHook, EMAHook, EvalHook,
+                    Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
+                    GradientCumulativeOptimizerHook, Hook, IterTimerHook,
+                    LoggerHook, LrUpdaterHook, MlflowLoggerHook,
+                    NeptuneLoggerHook, OptimizerHook, PaviLoggerHook,
+                    SyncBuffersHook, TensorboardLoggerHook, TextLoggerHook,
+                    WandbLoggerHook)
+from .iter_based_runner import IterBasedRunner, IterLoader
+from .log_buffer import LogBuffer
+from .optimizer import (OPTIMIZER_BUILDERS, OPTIMIZERS,
+                        DefaultOptimizerConstructor, build_optimizer,
+                        build_optimizer_constructor)
+from .priority import Priority, get_priority
+from .utils import get_host_info, get_time_str, obj_from_dict, set_random_seed
+
+__all__ = [
+    'BaseRunner', 'Runner', 'EpochBasedRunner', 'IterBasedRunner', 'LogBuffer',
+    'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
+    'OptimizerHook', 'IterTimerHook', 'DistSamplerSeedHook', 'LoggerHook',
+    'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
+    'NeptuneLoggerHook', 'WandbLoggerHook', 'MlflowLoggerHook',
+    'DvcliveLoggerHook', '_load_checkpoint', 'load_state_dict',
+    'load_checkpoint', 'weights_to_cpu', 'save_checkpoint', 'Priority',
+    'get_priority', 'get_host_info', 'get_time_str', 'obj_from_dict',
+    'init_dist', 'get_dist_info', 'master_only', 'OPTIMIZER_BUILDERS',
+    'OPTIMIZERS', 'DefaultOptimizerConstructor', 'build_optimizer',
+    'build_optimizer_constructor', 'IterLoader', 'set_random_seed',
+    'auto_fp16', 'force_fp32', 'wrap_fp16_model', 'Fp16OptimizerHook',
+    'SyncBuffersHook', 'EMAHook', 'build_runner', 'RUNNERS', 'allreduce_grads',
+    'allreduce_params', 'LossScaler', 'CheckpointLoader', 'BaseModule',
+    '_load_checkpoint_with_prefix', 'EvalHook', 'DistEvalHook', 'Sequential',
+    'ModuleList', 'GradientCumulativeOptimizerHook',
+    'GradientCumulativeFp16OptimizerHook', 'DefaultRunnerConstructor'
+]
diff --git a/annotator/uniformer/mmcv/runner/base_module.py b/annotator/uniformer/mmcv/runner/base_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..617fad9bb89f10a9a0911d962dfb3bc8f3a3628c
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/base_module.py
@@ -0,0 +1,195 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import warnings
+from abc import ABCMeta
+from collections import defaultdict
+from logging import FileHandler
+
+import torch.nn as nn
+
+from annotator.uniformer.mmcv.runner.dist_utils import master_only
+from annotator.uniformer.mmcv.utils.logging import get_logger, logger_initialized, print_log
+
+
+class BaseModule(nn.Module, metaclass=ABCMeta):
+    """Base module for all modules in openmmlab.
+
+    ``BaseModule`` is a wrapper of ``torch.nn.Module`` with additional
+    functionality of parameter initialization. Compared with
+    ``torch.nn.Module``, ``BaseModule`` mainly adds three attributes.
+
+        - ``init_cfg``: the config to control the initialization.
+        - ``init_weights``: The function of parameter
+            initialization and recording initialization
+            information.
+        - ``_params_init_info``: Used to track the parameter
+            initialization information. This attribute only
+            exists during executing the ``init_weights``.
+
+    Args:
+        init_cfg (dict, optional): Initialization config dict.
+    """
+
+    def __init__(self, init_cfg=None):
+        """Initialize BaseModule, inherited from `torch.nn.Module`"""
+
+        # NOTE init_cfg can be defined at different levels, but init_cfg
+        # at lower levels has higher priority.
+
+        super(BaseModule, self).__init__()
+        # define default value of init_cfg instead of hard code
+        # in init_weights() function
+        self._is_init = False
+
+        self.init_cfg = copy.deepcopy(init_cfg)
+
+        # Backward compatibility in derived classes
+        # if pretrained is not None:
+        #     warnings.warn('DeprecationWarning: pretrained is a deprecated \
+        #         key, please consider using init_cfg')
+        #     self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+
+    @property
+    def is_init(self):
+        return self._is_init
+
+    def init_weights(self):
+        """Initialize the weights."""
+
+        is_top_level_module = False
+        # check if it is top-level module
+        if not hasattr(self, '_params_init_info'):
+            # The `_params_init_info` is used to record the initialization
+            # information of the parameters
+            # the key should be the obj:`nn.Parameter` of model and the value
+            # should be a dict containing
+            # - init_info (str): The string that describes the initialization.
+            # - tmp_mean_value (FloatTensor): The mean of the parameter,
+            #       which indicates whether the parameter has been modified.
+            # this attribute will be deleted after all parameters
+            # are initialized.
+            self._params_init_info = defaultdict(dict)
+            is_top_level_module = True
+
+            # Initialize `_params_init_info`. When the `tmp_mean_value`
+            # of a parameter is detected to have changed, the related
+            # initialization information is updated.
+            for name, param in self.named_parameters():
+                self._params_init_info[param][
+                    'init_info'] = f'The value is the same before and ' \
+                                   f'after calling `init_weights` ' \
+                                   f'of {self.__class__.__name__} '
+                self._params_init_info[param][
+                    'tmp_mean_value'] = param.data.mean()
+
+            # pass `params_init_info` to all submodules
+            # All submodules share the same `params_init_info`,
+            # so it will be updated when parameters are
+            # modified at any level of the model.
+            for sub_module in self.modules():
+                sub_module._params_init_info = self._params_init_info
+
+        # Get an already-initialized logger; if none exists,
+        # create a logger named `mmcv`
+        logger_names = list(logger_initialized.keys())
+        logger_name = logger_names[0] if logger_names else 'mmcv'
+
+        from ..cnn import initialize
+        from ..cnn.utils.weight_init import update_init_info
+        module_name = self.__class__.__name__
+        if not self._is_init:
+            if self.init_cfg:
+                print_log(
+                    f'initialize {module_name} with init_cfg {self.init_cfg}',
+                    logger=logger_name)
+                initialize(self, self.init_cfg)
+                if isinstance(self.init_cfg, dict):
+                    # prevent the parameters of
+                    # the pre-trained model
+                    # from being overwritten by
+                    # the `init_weights`
+                    if self.init_cfg['type'] == 'Pretrained':
+                        return
+
+            for m in self.children():
+                if hasattr(m, 'init_weights'):
+                    m.init_weights()
+                    # users may overload the `init_weights`
+                    update_init_info(
+                        m,
+                        init_info=f'Initialized by '
+                        f'user-defined `init_weights`'
+                        f' in {m.__class__.__name__} ')
+
+            self._is_init = True
+        else:
+            warnings.warn(f'init_weights of {self.__class__.__name__} has '
+                          f'been called more than once.')
+
+        if is_top_level_module:
+            self._dump_init_info(logger_name)
+
+            for sub_module in self.modules():
+                del sub_module._params_init_info
+
+    @master_only
+    def _dump_init_info(self, logger_name):
+        """Dump the initialization information to a file named
+        `initialization.log.json` in workdir.
+
+        Args:
+            logger_name (str): The name of logger.
+        """
+
+        logger = get_logger(logger_name)
+
+        with_file_handler = False
+        # dump the information to the logger file if there is a `FileHandler`
+        for handler in logger.handlers:
+            if isinstance(handler, FileHandler):
+                handler.stream.write(
+                    'Name of parameter - Initialization information\n')
+                for name, param in self.named_parameters():
+                    handler.stream.write(
+                        f'\n{name} - {param.shape}: '
+                        f"\n{self._params_init_info[param]['init_info']} \n")
+                handler.stream.flush()
+                with_file_handler = True
+        if not with_file_handler:
+            for name, param in self.named_parameters():
+                print_log(
+                    f'\n{name} - {param.shape}: '
+                    f"\n{self._params_init_info[param]['init_info']} \n ",
+                    logger=logger_name)
+
+    def __repr__(self):
+        s = super().__repr__()
+        if self.init_cfg:
+            s += f'\ninit_cfg={self.init_cfg}'
+        return s
+
+
+class Sequential(BaseModule, nn.Sequential):
+    """Sequential module in openmmlab.
+
+    Args:
+        init_cfg (dict, optional): Initialization config dict.
+    """
+
+    def __init__(self, *args, init_cfg=None):
+        BaseModule.__init__(self, init_cfg)
+        nn.Sequential.__init__(self, *args)
+
+
+class ModuleList(BaseModule, nn.ModuleList):
+    """ModuleList in openmmlab.
+
+    Args:
+        modules (iterable, optional): an iterable of modules to add.
+        init_cfg (dict, optional): Initialization config dict.
+    """
+
+    def __init__(self, modules=None, init_cfg=None):
+        BaseModule.__init__(self, init_cfg)
+        nn.ModuleList.__init__(self, modules)
diff --git a/annotator/uniformer/mmcv/runner/base_runner.py b/annotator/uniformer/mmcv/runner/base_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..4928db0a73b56fe0218a4bf66ec4ffa082d31ccc
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/base_runner.py
@@ -0,0 +1,542 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import logging
+import os.path as osp
+import warnings
+from abc import ABCMeta, abstractmethod
+
+import torch
+from torch.optim import Optimizer
+
+import annotator.uniformer.mmcv as mmcv
+from ..parallel import is_module_wrapper
+from .checkpoint import load_checkpoint
+from .dist_utils import get_dist_info
+from .hooks import HOOKS, Hook
+from .log_buffer import LogBuffer
+from .priority import Priority, get_priority
+from .utils import get_time_str
+
+
+class BaseRunner(metaclass=ABCMeta):
+    """The base class of Runner, a training helper for PyTorch.
+
+    All subclasses should implement the following APIs:
+
+    - ``run()``
+    - ``train()``
+    - ``val()``
+    - ``save_checkpoint()``
+
+    Args:
+        model (:obj:`torch.nn.Module`): The model to be run.
+        batch_processor (callable): A callable method that processes a data
+            batch. The interface of this method should be
+            ``batch_processor(model, data, train_mode) -> dict``.
+        optimizer (dict or :obj:`torch.optim.Optimizer`): It can be either an
+            optimizer (in most cases) or a dict of optimizers (in models that
+            require more than one optimizer, e.g., GAN).
+        work_dir (str, optional): The working directory to save checkpoints
+            and logs. Defaults to None.
+        logger (:obj:`logging.Logger`): Logger used during training.
+             Defaults to None. (The default value is just for backward
+             compatibility)
+        meta (dict | None): A dict that records some important information,
+            such as environment info and seed, which will be logged in the
+            logger hook. Defaults to None.
+        max_epochs (int, optional): Total training epochs.
+        max_iters (int, optional): Total training iterations.
+    """
+
+    def __init__(self,
+                 model,
+                 batch_processor=None,
+                 optimizer=None,
+                 work_dir=None,
+                 logger=None,
+                 meta=None,
+                 max_iters=None,
+                 max_epochs=None):
+        if batch_processor is not None:
+            if not callable(batch_processor):
+                raise TypeError('batch_processor must be callable, '
+                                f'but got {type(batch_processor)}')
+            warnings.warn('batch_processor is deprecated, please implement '
+                          'train_step() and val_step() in the model instead.')
+            # raise an error if `batch_processor` is not None and
+            # `model.train_step()` exists.
+            if is_module_wrapper(model):
+                _model = model.module
+            else:
+                _model = model
+            if hasattr(_model, 'train_step') or hasattr(_model, 'val_step'):
+                raise RuntimeError(
+                    'batch_processor and model.train_step()/model.val_step() '
+                    'cannot be both available.')
+        else:
+            assert hasattr(model, 'train_step')
+
+        # check the type of `optimizer`
+        if isinstance(optimizer, dict):
+            for name, optim in optimizer.items():
+                if not isinstance(optim, Optimizer):
+                    raise TypeError(
+                        f'optimizer must be a dict of torch.optim.Optimizers, '
+                        f'but optimizer["{name}"] is a {type(optim)}')
+        elif not isinstance(optimizer, Optimizer) and optimizer is not None:
+            raise TypeError(
+                f'optimizer must be a torch.optim.Optimizer object '
+                f'or dict or None, but got {type(optimizer)}')
+
+        # check the type of `logger`
+        if not isinstance(logger, logging.Logger):
+            raise TypeError(f'logger must be a logging.Logger object, '
+                            f'but got {type(logger)}')
+
+        # check the type of `meta`
+        if meta is not None and not isinstance(meta, dict):
+            raise TypeError(
+                f'meta must be a dict or None, but got {type(meta)}')
+
+        self.model = model
+        self.batch_processor = batch_processor
+        self.optimizer = optimizer
+        self.logger = logger
+        self.meta = meta
+        # create work_dir
+        if mmcv.is_str(work_dir):
+            self.work_dir = osp.abspath(work_dir)
+            mmcv.mkdir_or_exist(self.work_dir)
+        elif work_dir is None:
+            self.work_dir = None
+        else:
+            raise TypeError('"work_dir" must be a str or None')
+
+        # get model name from the model class
+        if hasattr(self.model, 'module'):
+            self._model_name = self.model.module.__class__.__name__
+        else:
+            self._model_name = self.model.__class__.__name__
+
+        self._rank, self._world_size = get_dist_info()
+        self.timestamp = get_time_str()
+        self.mode = None
+        self._hooks = []
+        self._epoch = 0
+        self._iter = 0
+        self._inner_iter = 0
+
+        if max_epochs is not None and max_iters is not None:
+            raise ValueError(
+                'Only one of `max_epochs` or `max_iters` can be set.')
+
+        self._max_epochs = max_epochs
+        self._max_iters = max_iters
+        # TODO: Redesign LogBuffer, it is not flexible and elegant enough
+        self.log_buffer = LogBuffer()
+
+    @property
+    def model_name(self):
+        """str: Name of the model, usually the module class name."""
+        return self._model_name
+
+    @property
+    def rank(self):
+        """int: Rank of current process. (distributed training)"""
+        return self._rank
+
+    @property
+    def world_size(self):
+        """int: Number of processes participating in the job.
+        (distributed training)"""
+        return self._world_size
+
+    @property
+    def hooks(self):
+        """list[:obj:`Hook`]: A list of registered hooks."""
+        return self._hooks
+
+    @property
+    def epoch(self):
+        """int: Current epoch."""
+        return self._epoch
+
+    @property
+    def iter(self):
+        """int: Current iteration."""
+        return self._iter
+
+    @property
+    def inner_iter(self):
+        """int: Iteration in an epoch."""
+        return self._inner_iter
+
+    @property
+    def max_epochs(self):
+        """int: Maximum training epochs."""
+        return self._max_epochs
+
+    @property
+    def max_iters(self):
+        """int: Maximum training iterations."""
+        return self._max_iters
+
+    @abstractmethod
+    def train(self):
+        pass
+
+    @abstractmethod
+    def val(self):
+        pass
+
+    @abstractmethod
+    def run(self, data_loaders, workflow, **kwargs):
+        pass
+
+    @abstractmethod
+    def save_checkpoint(self,
+                        out_dir,
+                        filename_tmpl,
+                        save_optimizer=True,
+                        meta=None,
+                        create_symlink=True):
+        pass
+
+    def current_lr(self):
+        """Get current learning rates.
+
+        Returns:
+            list[float] | dict[str, list[float]]: Current learning rates of all
+                param groups. If the runner has a dict of optimizers, this
+                method will return a dict.
+        """
+        if isinstance(self.optimizer, torch.optim.Optimizer):
+            lr = [group['lr'] for group in self.optimizer.param_groups]
+        elif isinstance(self.optimizer, dict):
+            lr = dict()
+            for name, optim in self.optimizer.items():
+                lr[name] = [group['lr'] for group in optim.param_groups]
+        else:
+            raise RuntimeError(
+                'lr is not applicable because optimizer does not exist.')
+        return lr
+
+    def current_momentum(self):
+        """Get current momentums.
+
+        Returns:
+            list[float] | dict[str, list[float]]: Current momentums of all
+                param groups. If the runner has a dict of optimizers, this
+                method will return a dict.
+        """
+
+        def _get_momentum(optimizer):
+            momentums = []
+            for group in optimizer.param_groups:
+                if 'momentum' in group.keys():
+                    momentums.append(group['momentum'])
+                elif 'betas' in group.keys():
+                    momentums.append(group['betas'][0])
+                else:
+                    momentums.append(0)
+            return momentums
+
+        if self.optimizer is None:
+            raise RuntimeError(
+                'momentum is not applicable because optimizer does not exist.')
+        elif isinstance(self.optimizer, torch.optim.Optimizer):
+            momentums = _get_momentum(self.optimizer)
+        elif isinstance(self.optimizer, dict):
+            momentums = dict()
+            for name, optim in self.optimizer.items():
+                momentums[name] = _get_momentum(optim)
+        return momentums
+
+    def register_hook(self, hook, priority='NORMAL'):
+        """Register a hook into the hook list.
+
+        The hook will be inserted into a priority queue, with the specified
+        priority (See :class:`Priority` for details of priorities).
+        For hooks with the same priority, they will be triggered in the same
+        order as they are registered.
+
+        Args:
+            hook (:obj:`Hook`): The hook to be registered.
+            priority (int or str or :obj:`Priority`): Hook priority.
+                Lower value means higher priority.
+        """
+        assert isinstance(hook, Hook)
+        if hasattr(hook, 'priority'):
+            raise ValueError('"priority" is a reserved attribute for hooks')
+        priority = get_priority(priority)
+        hook.priority = priority
+        # insert the hook to a sorted list
+        inserted = False
+        for i in range(len(self._hooks) - 1, -1, -1):
+            if priority >= self._hooks[i].priority:
+                self._hooks.insert(i + 1, hook)
+                inserted = True
+                break
+        if not inserted:
+            self._hooks.insert(0, hook)
+
+    def register_hook_from_cfg(self, hook_cfg):
+        """Register a hook from its cfg.
+
+        Args:
+            hook_cfg (dict): Hook config. It should have at least the key
+              'type' indicating the hook type; an optional 'priority' key
+              (default 'NORMAL') specifies its priority.
+
+        Notes:
+            The specific hook class to register should not use 'type' and
+            'priority' arguments during initialization.
+        """
+        hook_cfg = hook_cfg.copy()
+        priority = hook_cfg.pop('priority', 'NORMAL')
+        hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
+        self.register_hook(hook, priority=priority)
+
+    def call_hook(self, fn_name):
+        """Call all hooks.
+
+        Args:
+            fn_name (str): The function name in each hook to be called, such as
+                "before_train_epoch".
+        """
+        for hook in self._hooks:
+            getattr(hook, fn_name)(self)
+
+    def get_hook_info(self):
+        # Get hooks info in each stage
+        stage_hook_map = {stage: [] for stage in Hook.stages}
+        for hook in self.hooks:
+            try:
+                priority = Priority(hook.priority).name
+            except ValueError:
+                priority = hook.priority
+            classname = hook.__class__.__name__
+            hook_info = f'({priority:<12}) {classname:<35}'
+            for trigger_stage in hook.get_triggered_stages():
+                stage_hook_map[trigger_stage].append(hook_info)
+
+        stage_hook_infos = []
+        for stage in Hook.stages:
+            hook_infos = stage_hook_map[stage]
+            if len(hook_infos) > 0:
+                info = f'{stage}:\n'
+                info += '\n'.join(hook_infos)
+                info += '\n -------------------- '
+                stage_hook_infos.append(info)
+        return '\n'.join(stage_hook_infos)
+
+    def load_checkpoint(self,
+                        filename,
+                        map_location='cpu',
+                        strict=False,
+                        revise_keys=[(r'^module.', '')]):
+        return load_checkpoint(
+            self.model,
+            filename,
+            map_location,
+            strict,
+            self.logger,
+            revise_keys=revise_keys)
+
+    def resume(self,
+               checkpoint,
+               resume_optimizer=True,
+               map_location='default'):
+        if map_location == 'default':
+            if torch.cuda.is_available():
+                device_id = torch.cuda.current_device()
+                checkpoint = self.load_checkpoint(
+                    checkpoint,
+                    map_location=lambda storage, loc: storage.cuda(device_id))
+            else:
+                checkpoint = self.load_checkpoint(checkpoint)
+        else:
+            checkpoint = self.load_checkpoint(
+                checkpoint, map_location=map_location)
+
+        self._epoch = checkpoint['meta']['epoch']
+        self._iter = checkpoint['meta']['iter']
+        if self.meta is None:
+            self.meta = {}
+        self.meta.setdefault('hook_msgs', {})
+        # load `last_ckpt`, `best_score`, `best_ckpt`, etc. for hook messages
+        self.meta['hook_msgs'].update(checkpoint['meta'].get('hook_msgs', {}))
+
+        # Re-calculate the number of iterations when resuming
+        # models with different number of GPUs
+        if 'config' in checkpoint['meta']:
+            config = mmcv.Config.fromstring(
+                checkpoint['meta']['config'], file_format='.py')
+            previous_gpu_ids = config.get('gpu_ids', None)
+            if previous_gpu_ids and len(previous_gpu_ids) > 0 and len(
+                    previous_gpu_ids) != self.world_size:
+                self._iter = int(self._iter * len(previous_gpu_ids) /
+                                 self.world_size)
+                self.logger.info('the iteration number is changed due to '
+                                 'change of GPU number')
+
+        # resume meta information
+        self.meta = checkpoint['meta']
+
+        if 'optimizer' in checkpoint and resume_optimizer:
+            if isinstance(self.optimizer, Optimizer):
+                self.optimizer.load_state_dict(checkpoint['optimizer'])
+            elif isinstance(self.optimizer, dict):
+                for k in self.optimizer.keys():
+                    self.optimizer[k].load_state_dict(
+                        checkpoint['optimizer'][k])
+            else:
+                raise TypeError(
+                    'Optimizer should be dict or torch.optim.Optimizer '
+                    f'but got {type(self.optimizer)}')
+
+        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
+
+    def register_lr_hook(self, lr_config):
+        if lr_config is None:
+            return
+        elif isinstance(lr_config, dict):
+            assert 'policy' in lr_config
+            policy_type = lr_config.pop('policy')
+            # If the policy type is all lower case, e.g., 'cyclic', its first
+            # letter is capitalized, e.g., to 'Cyclic'. This is for convenient
+            # usage of the LR updater. Since this is not applicable for
+            # `CosineAnnealingLrUpdater`, the string is left unchanged if it
+            # contains capital letters.
+            if policy_type == policy_type.lower():
+                policy_type = policy_type.title()
+            hook_type = policy_type + 'LrUpdaterHook'
+            lr_config['type'] = hook_type
+            hook = mmcv.build_from_cfg(lr_config, HOOKS)
+        else:
+            hook = lr_config
+        self.register_hook(hook, priority='VERY_HIGH')
+
+    def register_momentum_hook(self, momentum_config):
+        if momentum_config is None:
+            return
+        if isinstance(momentum_config, dict):
+            assert 'policy' in momentum_config
+            policy_type = momentum_config.pop('policy')
+            # If the type of policy is all in lower case, e.g., 'cyclic',
+            # then its first letter will be capitalized, e.g., to be 'Cyclic'.
+            # This is for the convenient usage of momentum updater.
+            # Since this is not applicable for
+            # `CosineAnnealingMomentumUpdater`,
+            # the string will not be changed if it contains capital letters.
+            if policy_type == policy_type.lower():
+                policy_type = policy_type.title()
+            hook_type = policy_type + 'MomentumUpdaterHook'
+            momentum_config['type'] = hook_type
+            hook = mmcv.build_from_cfg(momentum_config, HOOKS)
+        else:
+            hook = momentum_config
+        self.register_hook(hook, priority='HIGH')
+
+    def register_optimizer_hook(self, optimizer_config):
+        if optimizer_config is None:
+            return
+        if isinstance(optimizer_config, dict):
+            optimizer_config.setdefault('type', 'OptimizerHook')
+            hook = mmcv.build_from_cfg(optimizer_config, HOOKS)
+        else:
+            hook = optimizer_config
+        self.register_hook(hook, priority='ABOVE_NORMAL')
+
+    def register_checkpoint_hook(self, checkpoint_config):
+        if checkpoint_config is None:
+            return
+        if isinstance(checkpoint_config, dict):
+            checkpoint_config.setdefault('type', 'CheckpointHook')
+            hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
+        else:
+            hook = checkpoint_config
+        self.register_hook(hook, priority='NORMAL')
+
+    def register_logger_hooks(self, log_config):
+        if log_config is None:
+            return
+        log_interval = log_config['interval']
+        for info in log_config['hooks']:
+            logger_hook = mmcv.build_from_cfg(
+                info, HOOKS, default_args=dict(interval=log_interval))
+            self.register_hook(logger_hook, priority='VERY_LOW')
+
+    def register_timer_hook(self, timer_config):
+        if timer_config is None:
+            return
+        if isinstance(timer_config, dict):
+            timer_config_ = copy.deepcopy(timer_config)
+            hook = mmcv.build_from_cfg(timer_config_, HOOKS)
+        else:
+            hook = timer_config
+        self.register_hook(hook, priority='LOW')
+
+    def register_custom_hooks(self, custom_config):
+        if custom_config is None:
+            return
+
+        if not isinstance(custom_config, list):
+            custom_config = [custom_config]
+
+        for item in custom_config:
+            if isinstance(item, dict):
+                self.register_hook_from_cfg(item)
+            else:
+                self.register_hook(item, priority='NORMAL')
+
+    def register_profiler_hook(self, profiler_config):
+        if profiler_config is None:
+            return
+        if isinstance(profiler_config, dict):
+            profiler_config.setdefault('type', 'ProfilerHook')
+            hook = mmcv.build_from_cfg(profiler_config, HOOKS)
+        else:
+            hook = profiler_config
+        self.register_hook(hook)
+
+    def register_training_hooks(self,
+                                lr_config,
+                                optimizer_config=None,
+                                checkpoint_config=None,
+                                log_config=None,
+                                momentum_config=None,
+                                timer_config=dict(type='IterTimerHook'),
+                                custom_hooks_config=None):
+        """Register default and custom hooks for training.
+
+        Default and custom hooks include:
+
+        +----------------------+-------------------------+
+        | Hooks                | Priority                |
+        +======================+=========================+
+        | LrUpdaterHook        | VERY_HIGH (10)          |
+        +----------------------+-------------------------+
+        | MomentumUpdaterHook  | HIGH (30)               |
+        +----------------------+-------------------------+
+        | OptimizerStepperHook | ABOVE_NORMAL (40)       |
+        +----------------------+-------------------------+
+        | CheckpointSaverHook  | NORMAL (50)             |
+        +----------------------+-------------------------+
+        | IterTimerHook        | LOW (70)                |
+        +----------------------+-------------------------+
+        | LoggerHook(s)        | VERY_LOW (90)           |
+        +----------------------+-------------------------+
+        | CustomHook(s)        | defaults to NORMAL (50) |
+        +----------------------+-------------------------+
+
+        If custom hooks have the same priority as default hooks, the custom
+        hooks will be triggered after the default hooks.
+        """
+        self.register_lr_hook(lr_config)
+        self.register_momentum_hook(momentum_config)
+        self.register_optimizer_hook(optimizer_config)
+        self.register_checkpoint_hook(checkpoint_config)
+        self.register_timer_hook(timer_config)
+        self.register_logger_hooks(log_config)
+        self.register_custom_hooks(custom_hooks_config)
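+
+# Editor's note: an illustrative sketch (not part of upstream mmcv) of how the
+# registration helpers above are usually driven from config dicts; `runner`
+# is assumed to be an already constructed runner instance.
+#
+#     runner.register_training_hooks(
+#         lr_config=dict(policy='step', step=[8, 11]),   # -> StepLrUpdaterHook
+#         optimizer_config=dict(grad_clip=None),         # -> OptimizerHook
+#         checkpoint_config=dict(interval=1),            # -> CheckpointHook
+#         log_config=dict(interval=50,
+#                         hooks=[dict(type='TextLoggerHook')]),
+#         momentum_config=None)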
diff --git a/annotator/uniformer/mmcv/runner/builder.py b/annotator/uniformer/mmcv/runner/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..77c96ba0b2f30ead9da23f293c5dc84dd3e4a74f
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/builder.py
@@ -0,0 +1,24 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+from ..utils import Registry
+
+RUNNERS = Registry('runner')
+RUNNER_BUILDERS = Registry('runner builder')
+
+
+def build_runner_constructor(cfg):
+    return RUNNER_BUILDERS.build(cfg)
+
+
+def build_runner(cfg, default_args=None):
+    runner_cfg = copy.deepcopy(cfg)
+    constructor_type = runner_cfg.pop('constructor',
+                                      'DefaultRunnerConstructor')
+    runner_constructor = build_runner_constructor(
+        dict(
+            type=constructor_type,
+            runner_cfg=runner_cfg,
+            default_args=default_args))
+    runner = runner_constructor()
+    return runner
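+
+# Editor's note: an illustrative sketch (not part of upstream mmcv); `model`,
+# `optimizer` and `logger` are assumed to exist already.
+#
+#     runner = build_runner(
+#         dict(type='EpochBasedRunner', max_epochs=12),
+#         default_args=dict(model=model, optimizer=optimizer,
+#                           work_dir='./work_dir', logger=logger))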
diff --git a/annotator/uniformer/mmcv/runner/checkpoint.py b/annotator/uniformer/mmcv/runner/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..b29ca320679164432f446adad893e33fb2b4b29e
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/checkpoint.py
@@ -0,0 +1,707 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import io
+import os
+import os.path as osp
+import pkgutil
+import re
+import time
+import warnings
+from collections import OrderedDict
+from importlib import import_module
+from tempfile import TemporaryDirectory
+
+import torch
+import torchvision
+from torch.optim import Optimizer
+from torch.utils import model_zoo
+
+import annotator.uniformer.mmcv as mmcv
+from ..fileio import FileClient
+from ..fileio import load as load_file
+from ..parallel import is_module_wrapper
+from ..utils import mkdir_or_exist
+from .dist_utils import get_dist_info
+
+ENV_MMCV_HOME = 'MMCV_HOME'
+ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+DEFAULT_CACHE_DIR = '~/.cache'
+
+
+def _get_mmcv_home():
+    mmcv_home = os.path.expanduser(
+        os.getenv(
+            ENV_MMCV_HOME,
+            os.path.join(
+                os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))
+
+    mkdir_or_exist(mmcv_home)
+    return mmcv_home
+
+
+def load_state_dict(module, state_dict, strict=False, logger=None):
+    """Load state_dict to a module.
+
+    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
+    Default value for ``strict`` is set to ``False`` and the message for
+    param mismatch will be shown even if strict is False.
+
+    Args:
+        module (Module): Module that receives the state_dict.
+        state_dict (OrderedDict): Weights.
+        strict (bool): whether to strictly enforce that the keys
+            in :attr:`state_dict` match the keys returned by this module's
+            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
+        logger (:obj:`logging.Logger`, optional): Logger to log the error
+            message. If not specified, print function will be used.
+    """
+    unexpected_keys = []
+    all_missing_keys = []
+    err_msg = []
+
+    metadata = getattr(state_dict, '_metadata', None)
+    state_dict = state_dict.copy()
+    if metadata is not None:
+        state_dict._metadata = metadata
+
+    # use _load_from_state_dict to enable checkpoint version control
+    def load(module, prefix=''):
+        # recursively check parallel module in case that the model has a
+        # complicated structure, e.g., nn.Module(nn.Module(DDP))
+        if is_module_wrapper(module):
+            module = module.module
+        local_metadata = {} if metadata is None else metadata.get(
+            prefix[:-1], {})
+        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
+                                     all_missing_keys, unexpected_keys,
+                                     err_msg)
+        for name, child in module._modules.items():
+            if child is not None:
+                load(child, prefix + name + '.')
+
+    load(module)
+    load = None  # break load->load reference cycle
+
+    # ignore "num_batches_tracked" of BN layers
+    missing_keys = [
+        key for key in all_missing_keys if 'num_batches_tracked' not in key
+    ]
+
+    if unexpected_keys:
+        err_msg.append('unexpected key in source '
+                       f'state_dict: {", ".join(unexpected_keys)}\n')
+    if missing_keys:
+        err_msg.append(
+            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
+
+    rank, _ = get_dist_info()
+    if len(err_msg) > 0 and rank == 0:
+        err_msg.insert(
+            0, 'The model and loaded state dict do not match exactly\n')
+        err_msg = '\n'.join(err_msg)
+        if strict:
+            raise RuntimeError(err_msg)
+        elif logger is not None:
+            logger.warning(err_msg)
+        else:
+            print(err_msg)
+
+
+def get_torchvision_models():
+    model_urls = dict()
+    for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
+        if ispkg:
+            continue
+        _zoo = import_module(f'torchvision.models.{name}')
+        if hasattr(_zoo, 'model_urls'):
+            _urls = getattr(_zoo, 'model_urls')
+            model_urls.update(_urls)
+    return model_urls
+
+
+def get_external_models():
+    mmcv_home = _get_mmcv_home()
+    default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
+    default_urls = load_file(default_json_path)
+    assert isinstance(default_urls, dict)
+    external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
+    if osp.exists(external_json_path):
+        external_urls = load_file(external_json_path)
+        assert isinstance(external_urls, dict)
+        default_urls.update(external_urls)
+
+    return default_urls
+
+
+def get_mmcls_models():
+    mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
+    mmcls_urls = load_file(mmcls_json_path)
+
+    return mmcls_urls
+
+
+def get_deprecated_model_names():
+    deprecate_json_path = osp.join(mmcv.__path__[0],
+                                   'model_zoo/deprecated.json')
+    deprecate_urls = load_file(deprecate_json_path)
+    assert isinstance(deprecate_urls, dict)
+
+    return deprecate_urls
+
+
+def _process_mmcls_checkpoint(checkpoint):
+    state_dict = checkpoint['state_dict']
+    new_state_dict = OrderedDict()
+    for k, v in state_dict.items():
+        if k.startswith('backbone.'):
+            new_state_dict[k[9:]] = v
+    new_checkpoint = dict(state_dict=new_state_dict)
+
+    return new_checkpoint
+
+
+class CheckpointLoader:
+    """A general checkpoint loader to manage all schemes."""
+
+    _schemes = {}
+
+    @classmethod
+    def _register_scheme(cls, prefixes, loader, force=False):
+        if isinstance(prefixes, str):
+            prefixes = [prefixes]
+        else:
+            assert isinstance(prefixes, (list, tuple))
+        for prefix in prefixes:
+            if (prefix not in cls._schemes) or force:
+                cls._schemes[prefix] = loader
+            else:
+                raise KeyError(
+                    f'{prefix} is already registered as a loader backend, '
+                    'add "force=True" if you want to override it')
+        # sort, longer prefixes take priority
+        cls._schemes = OrderedDict(
+            sorted(cls._schemes.items(), key=lambda t: t[0], reverse=True))
+
+    @classmethod
+    def register_scheme(cls, prefixes, loader=None, force=False):
+        """Register a loader to CheckpointLoader.
+
+        This method can be used as a normal class method or a decorator.
+
+        Args:
+            prefixes (str or list[str] or tuple[str]): The prefix of the
+                registered loader.
+            loader (function, optional): The loader function to be registered.
+                When this method is used as a decorator, loader is None.
+                Defaults to None.
+            force (bool, optional): Whether to override the loader
+                if the prefix has already been registered. Defaults to False.
+        """
+
+        if loader is not None:
+            cls._register_scheme(prefixes, loader, force=force)
+            return
+
+        def _register(loader_cls):
+            cls._register_scheme(prefixes, loader_cls, force=force)
+            return loader_cls
+
+        return _register
+
+    @classmethod
+    def _get_checkpoint_loader(cls, path):
+        """Finds a loader that supports the given path. Falls back to the local
+        loader if no other loader is found.
+
+        Args:
+            path (str): checkpoint path
+
+        Returns:
+            loader (function): checkpoint loader
+        """
+
+        for p in cls._schemes:
+            if path.startswith(p):
+                return cls._schemes[p]
+
+    @classmethod
+    def load_checkpoint(cls, filename, map_location=None, logger=None):
+        """load checkpoint through URL scheme path.
+
+        Args:
+            filename (str): checkpoint file name with given prefix
+            map_location (str, optional): Same as :func:`torch.load`.
+                Default: None
+            logger (:mod:`logging.Logger`, optional): The logger for message.
+                Default: None
+
+        Returns:
+            dict or OrderedDict: The loaded checkpoint.
+        """
+
+        checkpoint_loader = cls._get_checkpoint_loader(filename)
+        class_name = checkpoint_loader.__name__
+        mmcv.print_log(
+            f'load checkpoint from {class_name[10:]} path: {filename}', logger)
+        return checkpoint_loader(filename, map_location)
+
+
+@CheckpointLoader.register_scheme(prefixes='')
+def load_from_local(filename, map_location):
+    """load checkpoint by local file path.
+
+    Args:
+        filename (str): local checkpoint file path
+        map_location (str, optional): Same as :func:`torch.load`.
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+    """
+
+    if not osp.isfile(filename):
+        raise IOError(f'{filename} is not a checkpoint file')
+    checkpoint = torch.load(filename, map_location=map_location)
+    return checkpoint
+
+
+@CheckpointLoader.register_scheme(prefixes=('http://', 'https://'))
+def load_from_http(filename, map_location=None, model_dir=None):
+    """load checkpoint through HTTP or HTTPS scheme path. In distributed
+    setting, this function only download checkpoint at local rank 0.
+
+    Args:
+        filename (str): checkpoint file path with modelzoo or
+            torchvision prefix
+        map_location (str, optional): Same as :func:`torch.load`.
+        model_dir (string, optional): directory in which to save the object,
+            Default: None
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+    """
+    rank, world_size = get_dist_info()
+    rank = int(os.environ.get('LOCAL_RANK', rank))
+    if rank == 0:
+        checkpoint = model_zoo.load_url(
+            filename, model_dir=model_dir, map_location=map_location)
+    if world_size > 1:
+        torch.distributed.barrier()
+        if rank > 0:
+            checkpoint = model_zoo.load_url(
+                filename, model_dir=model_dir, map_location=map_location)
+    return checkpoint
+
+
+@CheckpointLoader.register_scheme(prefixes='pavi://')
+def load_from_pavi(filename, map_location=None):
+    """load checkpoint through the file path prefixed with pavi. In distributed
+    setting, this function download ckpt at all ranks to different temporary
+    directories.
+
+    Args:
+        filename (str): checkpoint file path with pavi prefix
+        map_location (str, optional): Same as :func:`torch.load`.
+          Default: None
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+    """
+    assert filename.startswith('pavi://'), \
+        f'Expected filename to start with `pavi://`, but got {filename}'
+    model_path = filename[7:]
+
+    try:
+        from pavi import modelcloud
+    except ImportError:
+        raise ImportError(
+            'Please install pavi to load checkpoint from modelcloud.')
+
+    model = modelcloud.get(model_path)
+    with TemporaryDirectory() as tmp_dir:
+        downloaded_file = osp.join(tmp_dir, model.name)
+        model.download(downloaded_file)
+        checkpoint = torch.load(downloaded_file, map_location=map_location)
+    return checkpoint
+
+
+@CheckpointLoader.register_scheme(prefixes='s3://')
+def load_from_ceph(filename, map_location=None, backend='petrel'):
+    """load checkpoint through the file path prefixed with s3.  In distributed
+    setting, this function download ckpt at all ranks to different temporary
+    directories.
+
+    Args:
+        filename (str): checkpoint file path with s3 prefix
+        map_location (str, optional): Same as :func:`torch.load`.
+        backend (str, optional): The storage backend type. Options are 'ceph',
+            'petrel'. Default: 'petrel'.
+
+    .. warning::
+        :class:`mmcv.fileio.file_client.CephBackend` will be deprecated,
+        please use :class:`mmcv.fileio.file_client.PetrelBackend` instead.
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+    """
+    allowed_backends = ['ceph', 'petrel']
+    if backend not in allowed_backends:
+        raise ValueError(f'Load from Backend {backend} is not supported.')
+
+    if backend == 'ceph':
+        warnings.warn(
+            'CephBackend will be deprecated, please use PetrelBackend instead')
+
+    # CephBackend and PetrelBackend share the prefix 's3://', and the latter
+    # is chosen by default. If PetrelBackend cannot be instantiated
+    # successfully, CephBackend is used as a fallback.
+    try:
+        file_client = FileClient(backend=backend)
+    except ImportError:
+        allowed_backends.remove(backend)
+        file_client = FileClient(backend=allowed_backends[0])
+
+    with io.BytesIO(file_client.get(filename)) as buffer:
+        checkpoint = torch.load(buffer, map_location=map_location)
+    return checkpoint
+
+
+@CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://'))
+def load_from_torchvision(filename, map_location=None):
+    """load checkpoint through the file path prefixed with modelzoo or
+    torchvision.
+
+    Args:
+        filename (str): checkpoint file path with modelzoo or
+            torchvision prefix
+        map_location (str, optional): Same as :func:`torch.load`.
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+    """
+    model_urls = get_torchvision_models()
+    if filename.startswith('modelzoo://'):
+        warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
+                      'use "torchvision://" instead')
+        model_name = filename[11:]
+    else:
+        model_name = filename[14:]
+    return load_from_http(model_urls[model_name], map_location=map_location)
+
+
+@CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://'))
+def load_from_openmmlab(filename, map_location=None):
+    """load checkpoint through the file path prefixed with open-mmlab or
+    openmmlab.
+
+    Args:
+        filename (str): checkpoint file path with open-mmlab or
+            openmmlab prefix
+        map_location (str, optional): Same as :func:`torch.load`.
+          Default: None
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+    """
+
+    model_urls = get_external_models()
+    prefix_str = 'open-mmlab://'
+    if filename.startswith(prefix_str):
+        model_name = filename[13:]
+    else:
+        model_name = filename[12:]
+        prefix_str = 'openmmlab://'
+
+    deprecated_urls = get_deprecated_model_names()
+    if model_name in deprecated_urls:
+        warnings.warn(f'{prefix_str}{model_name} is deprecated in favor '
+                      f'of {prefix_str}{deprecated_urls[model_name]}')
+        model_name = deprecated_urls[model_name]
+    model_url = model_urls[model_name]
+    # check if is url
+    if model_url.startswith(('http://', 'https://')):
+        checkpoint = load_from_http(model_url, map_location=map_location)
+    else:
+        filename = osp.join(_get_mmcv_home(), model_url)
+        if not osp.isfile(filename):
+            raise IOError(f'{filename} is not a checkpoint file')
+        checkpoint = torch.load(filename, map_location=map_location)
+    return checkpoint
+
+
+@CheckpointLoader.register_scheme(prefixes='mmcls://')
+def load_from_mmcls(filename, map_location=None):
+    """load checkpoint through the file path prefixed with mmcls.
+
+    Args:
+        filename (str): checkpoint file path with mmcls prefix
+        map_location (str, optional): Same as :func:`torch.load`.
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+    """
+
+    model_urls = get_mmcls_models()
+    model_name = filename[8:]
+    checkpoint = load_from_http(
+        model_urls[model_name], map_location=map_location)
+    checkpoint = _process_mmcls_checkpoint(checkpoint)
+    return checkpoint
+
+
+def _load_checkpoint(filename, map_location=None, logger=None):
+    """Load checkpoint from somewhere (modelzoo, file, url).
+
+    Args:
+        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+            details.
+        map_location (str, optional): Same as :func:`torch.load`.
+           Default: None.
+        logger (:mod:`logging.Logger`, optional): The logger for error message.
+           Default: None
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint. It can be either an
+           OrderedDict storing model weights or a dict containing other
+           information, which depends on the checkpoint.
+    """
+    return CheckpointLoader.load_checkpoint(filename, map_location, logger)
+
+
+def _load_checkpoint_with_prefix(prefix, filename, map_location=None):
+    """Load partial pretrained model with specific prefix.
+
+    Args:
+        prefix (str): The prefix of sub-module.
+        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+            details.
+        map_location (str | None): Same as :func:`torch.load`. Default: None.
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+    """
+
+    checkpoint = _load_checkpoint(filename, map_location=map_location)
+
+    if 'state_dict' in checkpoint:
+        state_dict = checkpoint['state_dict']
+    else:
+        state_dict = checkpoint
+    if not prefix.endswith('.'):
+        prefix += '.'
+    prefix_len = len(prefix)
+
+    state_dict = {
+        k[prefix_len:]: v
+        for k, v in state_dict.items() if k.startswith(prefix)
+    }
+
+    assert state_dict, f'{prefix} is not in the pretrained model'
+    return state_dict
+
+
+def load_checkpoint(model,
+                    filename,
+                    map_location=None,
+                    strict=False,
+                    logger=None,
+                    revise_keys=[(r'^module\.', '')]):
+    """Load checkpoint from a file or URI.
+
+    Args:
+        model (Module): Module to load checkpoint.
+        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+            details.
+        map_location (str): Same as :func:`torch.load`.
+        strict (bool): Whether to strictly enforce that the keys in the
+            checkpoint match the keys returned by the model's state_dict.
+        logger (:mod:`logging.Logger` or None): The logger for error message.
+        revise_keys (list): A list of customized keywords to modify the
+            state_dict in checkpoint. Each item is a (pattern, replacement)
+            pair of the regular expression operations. Default: strip
+            the prefix 'module.' by [(r'^module\\.', '')].
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+    """
+    checkpoint = _load_checkpoint(filename, map_location, logger)
+    # OrderedDict is a subclass of dict
+    if not isinstance(checkpoint, dict):
+        raise RuntimeError(
+            f'No state_dict found in checkpoint file {filename}')
+    # get state_dict from checkpoint
+    if 'state_dict' in checkpoint:
+        state_dict = checkpoint['state_dict']
+    else:
+        state_dict = checkpoint
+
+    # strip prefix of state_dict
+    metadata = getattr(state_dict, '_metadata', OrderedDict())
+    for p, r in revise_keys:
+        state_dict = OrderedDict(
+            {re.sub(p, r, k): v
+             for k, v in state_dict.items()})
+    # Keep metadata in state_dict
+    state_dict._metadata = metadata
+
+    # load state_dict
+    load_state_dict(model, state_dict, strict, logger)
+    return checkpoint
+
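+# Editor's note: `revise_keys` takes (pattern, replacement) regex pairs that
+# are applied to every key before loading, which helps when a checkpoint was
+# saved from a wrapped or prefixed model. Illustrative only; the file name
+# below is hypothetical.
+#
+#     load_checkpoint(model, 'epoch_12.pth', map_location='cpu',
+#                     revise_keys=[(r'^module\.', ''),     # strip DDP prefix
+#                                  (r'^backbone\.', '')])  # strip submodule prefix
+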
+
+def weights_to_cpu(state_dict):
+    """Copy a model state_dict to cpu.
+
+    Args:
+        state_dict (OrderedDict): Model weights on GPU.
+
+    Returns:
+        OrderedDict: Model weights on CPU.
+    """
+    state_dict_cpu = OrderedDict()
+    for key, val in state_dict.items():
+        state_dict_cpu[key] = val.cpu()
+    # Keep metadata in state_dict
+    state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
+    return state_dict_cpu
+
+
+def _save_to_state_dict(module, destination, prefix, keep_vars):
+    """Saves module state to `destination` dictionary.
+
+    This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
+
+    Args:
+        module (nn.Module): The module to generate state_dict.
+        destination (dict): A dict where state will be stored.
+        prefix (str): The prefix for parameters and buffers used in this
+            module.
+    """
+    for name, param in module._parameters.items():
+        if param is not None:
+            destination[prefix + name] = param if keep_vars else param.detach()
+    for name, buf in module._buffers.items():
+        # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
+        if buf is not None:
+            destination[prefix + name] = buf if keep_vars else buf.detach()
+
+
+def get_state_dict(module, destination=None, prefix='', keep_vars=False):
+    """Returns a dictionary containing a whole state of the module.
+
+    Both parameters and persistent buffers (e.g. running averages) are
+    included. Keys are corresponding parameter and buffer names.
+
+    This method is modified from :meth:`torch.nn.Module.state_dict` to
+    recursively check parallel module in case that the model has a complicated
+    structure, e.g., nn.Module(nn.Module(DDP)).
+
+    Args:
+        module (nn.Module): The module to generate state_dict.
+        destination (OrderedDict): Returned dict for the state of the
+            module.
+        prefix (str): Prefix of the key.
+        keep_vars (bool): Whether to keep the variable property of the
+            parameters. Default: False.
+
+    Returns:
+        dict: A dictionary containing a whole state of the module.
+    """
+    # recursively check parallel module in case that the model has a
+    # complicated structure, e.g., nn.Module(nn.Module(DDP))
+    if is_module_wrapper(module):
+        module = module.module
+
+    # below is the same as torch.nn.Module.state_dict()
+    if destination is None:
+        destination = OrderedDict()
+        destination._metadata = OrderedDict()
+    destination._metadata[prefix[:-1]] = local_metadata = dict(
+        version=module._version)
+    _save_to_state_dict(module, destination, prefix, keep_vars)
+    for name, child in module._modules.items():
+        if child is not None:
+            get_state_dict(
+                child, destination, prefix + name + '.', keep_vars=keep_vars)
+    for hook in module._state_dict_hooks.values():
+        hook_result = hook(module, destination, prefix, local_metadata)
+        if hook_result is not None:
+            destination = hook_result
+    return destination
+
+
+def save_checkpoint(model,
+                    filename,
+                    optimizer=None,
+                    meta=None,
+                    file_client_args=None):
+    """Save checkpoint to file.
+
+    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
+    ``optimizer``. By default ``meta`` will contain version and time info.
+
+    Args:
+        model (Module): Module whose params are to be saved.
+        filename (str): Checkpoint filename.
+        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
+        meta (dict, optional): Metadata to be saved in checkpoint.
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.
+            `New in version 1.3.16.`
+    """
+    if meta is None:
+        meta = {}
+    elif not isinstance(meta, dict):
+        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
+    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
+
+    if is_module_wrapper(model):
+        model = model.module
+
+    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
+        # save class name to the meta
+        meta.update(CLASSES=model.CLASSES)
+
+    checkpoint = {
+        'meta': meta,
+        'state_dict': weights_to_cpu(get_state_dict(model))
+    }
+    # save optimizer state dict in the checkpoint
+    if isinstance(optimizer, Optimizer):
+        checkpoint['optimizer'] = optimizer.state_dict()
+    elif isinstance(optimizer, dict):
+        checkpoint['optimizer'] = {}
+        for name, optim in optimizer.items():
+            checkpoint['optimizer'][name] = optim.state_dict()
+
+    if filename.startswith('pavi://'):
+        if file_client_args is not None:
+            raise ValueError(
+                'file_client_args should be "None" if filename starts with'
+                f'"pavi://", but got {file_client_args}')
+        try:
+            from pavi import modelcloud
+            from pavi import exception
+        except ImportError:
+            raise ImportError(
+                'Please install pavi to load checkpoint from modelcloud.')
+        model_path = filename[7:]
+        root = modelcloud.Folder()
+        model_dir, model_name = osp.split(model_path)
+        try:
+            model = modelcloud.get(model_dir)
+        except exception.NodeNotFoundError:
+            model = root.create_training_model(model_dir)
+        with TemporaryDirectory() as tmp_dir:
+            checkpoint_file = osp.join(tmp_dir, model_name)
+            with open(checkpoint_file, 'wb') as f:
+                torch.save(checkpoint, f)
+                f.flush()
+            model.create_file(checkpoint_file, name=model_name)
+    else:
+        file_client = FileClient.infer_client(file_client_args, filename)
+        with io.BytesIO() as f:
+            torch.save(checkpoint, f)
+            file_client.put(f.getvalue(), filename)
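+
+# Editor's note: an illustrative summary of the prefix schemes registered
+# above (paths, URLs and file names below are hypothetical; the torchvision
+# scheme assumes the named model exists in torchvision's model zoo).
+#
+#     load_checkpoint(model, '/path/to/local.pth')            # local file
+#     load_checkpoint(model, 'torchvision://resnet50')        # torchvision zoo
+#     load_checkpoint(model, 'https://example.com/ckpt.pth')  # HTTP(S) download
+#     save_checkpoint(model, 'epoch_1.pth', optimizer=optimizer)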
diff --git a/annotator/uniformer/mmcv/runner/default_constructor.py b/annotator/uniformer/mmcv/runner/default_constructor.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f1f5b44168768dfda3947393a63a6cf9cf50b41
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/default_constructor.py
@@ -0,0 +1,44 @@
+from .builder import RUNNER_BUILDERS, RUNNERS
+
+
+@RUNNER_BUILDERS.register_module()
+class DefaultRunnerConstructor:
+    """Default constructor for runners.
+
+    Customize an existing `Runner` such as `EpochBasedRunner` through a
+    `RunnerConstructor`. For example, we can inject new properties and
+    functions into the `Runner`.
+
+    Example:
+        >>> from annotator.uniformer.mmcv.runner import RUNNER_BUILDERS, build_runner
+        >>> # Define a new RunnerConstructor
+        >>> @RUNNER_BUILDERS.register_module()
+        >>> class MyRunnerConstructor:
+        ...     def __init__(self, runner_cfg, default_args=None):
+        ...         if not isinstance(runner_cfg, dict):
+        ...             raise TypeError('runner_cfg should be a dict',
+        ...                             f'but got {type(runner_cfg)}')
+        ...         self.runner_cfg = runner_cfg
+        ...         self.default_args = default_args
+        ...
+        ...     def __call__(self):
+        ...         runner = RUNNERS.build(self.runner_cfg,
+        ...                                default_args=self.default_args)
+        ...         # Add new properties for existing runner
+        ...         runner.my_name = 'my_runner'
+        ...         runner.my_function = lambda self: print(self.my_name)
+        ...         ...
+        >>> # build your runner
+        >>> runner_cfg = dict(type='EpochBasedRunner', max_epochs=40,
+        ...                   constructor='MyRunnerConstructor')
+        >>> runner = build_runner(runner_cfg)
+    """
+
+    def __init__(self, runner_cfg, default_args=None):
+        if not isinstance(runner_cfg, dict):
+            raise TypeError('runner_cfg should be a dict',
+                            f'but got {type(runner_cfg)}')
+        self.runner_cfg = runner_cfg
+        self.default_args = default_args
+
+    def __call__(self):
+        return RUNNERS.build(self.runner_cfg, default_args=self.default_args)
diff --git a/annotator/uniformer/mmcv/runner/dist_utils.py b/annotator/uniformer/mmcv/runner/dist_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3a1ef3fda5ceeb31bf15a73779da1b1903ab0fe
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/dist_utils.py
@@ -0,0 +1,164 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+import os
+import subprocess
+from collections import OrderedDict
+
+import torch
+import torch.multiprocessing as mp
+from torch import distributed as dist
+from torch._utils import (_flatten_dense_tensors, _take_tensors,
+                          _unflatten_dense_tensors)
+
+
+def init_dist(launcher, backend='nccl', **kwargs):
+    if mp.get_start_method(allow_none=True) is None:
+        mp.set_start_method('spawn')
+    if launcher == 'pytorch':
+        _init_dist_pytorch(backend, **kwargs)
+    elif launcher == 'mpi':
+        _init_dist_mpi(backend, **kwargs)
+    elif launcher == 'slurm':
+        _init_dist_slurm(backend, **kwargs)
+    else:
+        raise ValueError(f'Invalid launcher type: {launcher}')
+
+
+def _init_dist_pytorch(backend, **kwargs):
+    # TODO: use local_rank instead of rank % num_gpus
+    rank = int(os.environ['RANK'])
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(rank % num_gpus)
+    dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_mpi(backend, **kwargs):
+    # TODO: use local_rank instead of rank % num_gpus
+    rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(rank % num_gpus)
+    dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_slurm(backend, port=None):
+    """Initialize slurm distributed training environment.
+
+    If argument ``port`` is not specified, then the master port will be system
+    environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
+    environment variable, then a default port ``29500`` will be used.
+
+    Args:
+        backend (str): Backend of torch.distributed.
+        port (int, optional): Master port. Defaults to None.
+    """
+    proc_id = int(os.environ['SLURM_PROCID'])
+    ntasks = int(os.environ['SLURM_NTASKS'])
+    node_list = os.environ['SLURM_NODELIST']
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(proc_id % num_gpus)
+    addr = subprocess.getoutput(
+        f'scontrol show hostname {node_list} | head -n1')
+    # specify master port
+    if port is not None:
+        os.environ['MASTER_PORT'] = str(port)
+    elif 'MASTER_PORT' in os.environ:
+        pass  # use MASTER_PORT in the environment variable
+    else:
+        # 29500 is torch.distributed default port
+        os.environ['MASTER_PORT'] = '29500'
+    # use MASTER_ADDR in the environment variable if it already exists
+    if 'MASTER_ADDR' not in os.environ:
+        os.environ['MASTER_ADDR'] = addr
+    os.environ['WORLD_SIZE'] = str(ntasks)
+    os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
+    os.environ['RANK'] = str(proc_id)
+    dist.init_process_group(backend=backend)
+
+
+def get_dist_info():
+    if dist.is_available() and dist.is_initialized():
+        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+    else:
+        rank = 0
+        world_size = 1
+    return rank, world_size
+
+
+def master_only(func):
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        rank, _ = get_dist_info()
+        if rank == 0:
+            return func(*args, **kwargs)
+
+    return wrapper
+
+
+def allreduce_params(params, coalesce=True, bucket_size_mb=-1):
+    """Allreduce parameters.
+
+    Args:
+        params (list[torch.Parameters]): List of parameters or buffers of a
+            model.
+        coalesce (bool, optional): Whether allreduce parameters as a whole.
+            Defaults to True.
+        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
+            Defaults to -1.
+    """
+    _, world_size = get_dist_info()
+    if world_size == 1:
+        return
+    params = [param.data for param in params]
+    if coalesce:
+        _allreduce_coalesced(params, world_size, bucket_size_mb)
+    else:
+        for tensor in params:
+            dist.all_reduce(tensor.div_(world_size))
+
+
+def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
+    """Allreduce gradients.
+
+    Args:
+        params (list[torch.Parameters]): List of parameters of a model
+        coalesce (bool, optional): Whether allreduce parameters as a whole.
+            Defaults to True.
+        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
+            Defaults to -1.
+    """
+    grads = [
+        param.grad.data for param in params
+        if param.requires_grad and param.grad is not None
+    ]
+    _, world_size = get_dist_info()
+    if world_size == 1:
+        return
+    if coalesce:
+        _allreduce_coalesced(grads, world_size, bucket_size_mb)
+    else:
+        for tensor in grads:
+            dist.all_reduce(tensor.div_(world_size))
+
+
+def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
+    if bucket_size_mb > 0:
+        bucket_size_bytes = bucket_size_mb * 1024 * 1024
+        buckets = _take_tensors(tensors, bucket_size_bytes)
+    else:
+        buckets = OrderedDict()
+        for tensor in tensors:
+            tp = tensor.type()
+            if tp not in buckets:
+                buckets[tp] = []
+            buckets[tp].append(tensor)
+        buckets = buckets.values()
+
+    for bucket in buckets:
+        flat_tensors = _flatten_dense_tensors(bucket)
+        dist.all_reduce(flat_tensors)
+        flat_tensors.div_(world_size)
+        for tensor, synced in zip(
+                bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
+            tensor.copy_(synced)
diff --git a/annotator/uniformer/mmcv/runner/epoch_based_runner.py b/annotator/uniformer/mmcv/runner/epoch_based_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..766a9ce6afdf09cd11b1b15005f5132583011348
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/epoch_based_runner.py
@@ -0,0 +1,187 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import platform
+import shutil
+import time
+import warnings
+
+import torch
+
+import annotator.uniformer.mmcv as mmcv
+from .base_runner import BaseRunner
+from .builder import RUNNERS
+from .checkpoint import save_checkpoint
+from .utils import get_host_info
+
+
+@RUNNERS.register_module()
+class EpochBasedRunner(BaseRunner):
+    """Epoch-based Runner.
+
+    This runner trains models epoch by epoch.
+    """
+
+    def run_iter(self, data_batch, train_mode, **kwargs):
+        if self.batch_processor is not None:
+            outputs = self.batch_processor(
+                self.model, data_batch, train_mode=train_mode, **kwargs)
+        elif train_mode:
+            outputs = self.model.train_step(data_batch, self.optimizer,
+                                            **kwargs)
+        else:
+            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
+        if not isinstance(outputs, dict):
+            raise TypeError('"batch_processor()" or "model.train_step()"'
+                            'and "model.val_step()" must return a dict')
+        if 'log_vars' in outputs:
+            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
+        self.outputs = outputs
+
+    def train(self, data_loader, **kwargs):
+        self.model.train()
+        self.mode = 'train'
+        self.data_loader = data_loader
+        self._max_iters = self._max_epochs * len(self.data_loader)
+        self.call_hook('before_train_epoch')
+        time.sleep(2)  # Prevent possible deadlock during epoch transition
+        for i, data_batch in enumerate(self.data_loader):
+            self._inner_iter = i
+            self.call_hook('before_train_iter')
+            self.run_iter(data_batch, train_mode=True, **kwargs)
+            self.call_hook('after_train_iter')
+            self._iter += 1
+
+        self.call_hook('after_train_epoch')
+        self._epoch += 1
+
+    @torch.no_grad()
+    def val(self, data_loader, **kwargs):
+        self.model.eval()
+        self.mode = 'val'
+        self.data_loader = data_loader
+        self.call_hook('before_val_epoch')
+        time.sleep(2)  # Prevent possible deadlock during epoch transition
+        for i, data_batch in enumerate(self.data_loader):
+            self._inner_iter = i
+            self.call_hook('before_val_iter')
+            self.run_iter(data_batch, train_mode=False)
+            self.call_hook('after_val_iter')
+
+        self.call_hook('after_val_epoch')
+
+    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
+        """Start running.
+
+        Args:
+            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
+                and validation.
+            workflow (list[tuple]): A list of (phase, epochs) to specify the
+                running order and epochs. E.g., [('train', 2), ('val', 1)] means
+                running 2 epochs for training and 1 epoch for validation,
+                iteratively.
+        """
+        assert isinstance(data_loaders, list)
+        assert mmcv.is_list_of(workflow, tuple)
+        assert len(data_loaders) == len(workflow)
+        if max_epochs is not None:
+            warnings.warn(
+                'setting max_epochs in run is deprecated, '
+                'please set max_epochs in runner_config', DeprecationWarning)
+            self._max_epochs = max_epochs
+
+        assert self._max_epochs is not None, (
+            'max_epochs must be specified during instantiation')
+
+        for i, flow in enumerate(workflow):
+            mode, epochs = flow
+            if mode == 'train':
+                self._max_iters = self._max_epochs * len(data_loaders[i])
+                break
+
+        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
+        self.logger.info('Start running, host: %s, work_dir: %s',
+                         get_host_info(), work_dir)
+        self.logger.info('Hooks will be executed in the following order:\n%s',
+                         self.get_hook_info())
+        self.logger.info('workflow: %s, max: %d epochs', workflow,
+                         self._max_epochs)
+        self.call_hook('before_run')
+
+        while self.epoch < self._max_epochs:
+            for i, flow in enumerate(workflow):
+                mode, epochs = flow
+                if isinstance(mode, str):  # self.train()
+                    if not hasattr(self, mode):
+                        raise ValueError(
+                            f'runner has no method named "{mode}" to run an '
+                            'epoch')
+                    epoch_runner = getattr(self, mode)
+                else:
+                    raise TypeError(
+                        'mode in workflow must be a str, but got {}'.format(
+                            type(mode)))
+
+                for _ in range(epochs):
+                    if mode == 'train' and self.epoch >= self._max_epochs:
+                        break
+                    epoch_runner(data_loaders[i], **kwargs)
+
+        time.sleep(1)  # wait for some hooks like loggers to finish
+        self.call_hook('after_run')
+
+    def save_checkpoint(self,
+                        out_dir,
+                        filename_tmpl='epoch_{}.pth',
+                        save_optimizer=True,
+                        meta=None,
+                        create_symlink=True):
+        """Save the checkpoint.
+
+        Args:
+            out_dir (str): The directory that checkpoints are saved.
+            filename_tmpl (str, optional): The checkpoint filename template,
+                which contains a placeholder for the epoch number.
+                Defaults to 'epoch_{}.pth'.
+            save_optimizer (bool, optional): Whether to save the optimizer to
+                the checkpoint. Defaults to True.
+            meta (dict, optional): The meta information to be saved in the
+                checkpoint. Defaults to None.
+            create_symlink (bool, optional): Whether to create a symlink
+                "latest.pth" to point to the latest checkpoint.
+                Defaults to True.
+        """
+        if meta is None:
+            meta = {}
+        elif not isinstance(meta, dict):
+            raise TypeError(
+                f'meta should be a dict or None, but got {type(meta)}')
+        if self.meta is not None:
+            meta.update(self.meta)
+            # Note: meta.update(self.meta) should be done before
+            # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise
+            # there will be problems with resumed checkpoints.
+            # More details in https://github.com/open-mmlab/mmcv/pull/1108
+        meta.update(epoch=self.epoch + 1, iter=self.iter)
+
+        filename = filename_tmpl.format(self.epoch + 1)
+        filepath = osp.join(out_dir, filename)
+        optimizer = self.optimizer if save_optimizer else None
+        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
+        # in some environments, `os.symlink` is not supported, you may need to
+        # set `create_symlink` to False
+        if create_symlink:
+            dst_file = osp.join(out_dir, 'latest.pth')
+            if platform.system() != 'Windows':
+                mmcv.symlink(filename, dst_file)
+            else:
+                shutil.copy(filepath, dst_file)
+
+
+@RUNNERS.register_module()
+class Runner(EpochBasedRunner):
+    """Deprecated name of EpochBasedRunner."""
+
+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            'Runner was deprecated, please use EpochBasedRunner instead')
+        super().__init__(*args, **kwargs)
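+
+# Editor's note: an illustrative sketch (not part of upstream mmcv); `runner`,
+# `train_loader` and `val_loader` are assumed to exist already.
+#
+#     runner.run([train_loader, val_loader],
+#                workflow=[('train', 1), ('val', 1)])
+#     # train one epoch, then validate one epoch, until max_epochs is reached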
diff --git a/annotator/uniformer/mmcv/runner/fp16_utils.py b/annotator/uniformer/mmcv/runner/fp16_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1981011d6859192e3e663e29d13500d56ba47f6c
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/fp16_utils.py
@@ -0,0 +1,410 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import functools
+import warnings
+from collections import abc
+from inspect import getfullargspec
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version
+from .dist_utils import allreduce_grads as _allreduce_grads
+
+try:
+    # If PyTorch version >= 1.6.0, torch.cuda.amp.autocast would be imported
+    # and used; otherwise, auto fp16 will adopt mmcv's implementation.
+    # Note that when PyTorch >= 1.6.0, we still cast tensor types to fp16
+    # manually, so the behavior may not be consistent with real amp.
+    from torch.cuda.amp import autocast
+except ImportError:
+    pass
+
+
+def cast_tensor_type(inputs, src_type, dst_type):
+    """Recursively convert Tensor in inputs from src_type to dst_type.
+
+    Args:
+        inputs: Inputs to be cast.
+        src_type (torch.dtype): Source type.
+        dst_type (torch.dtype): Destination type.
+
+    Returns:
+        The same type as inputs, but with all contained Tensors cast to
+        dst_type.
+    """
+    if isinstance(inputs, nn.Module):
+        return inputs
+    elif isinstance(inputs, torch.Tensor):
+        return inputs.to(dst_type)
+    elif isinstance(inputs, str):
+        return inputs
+    elif isinstance(inputs, np.ndarray):
+        return inputs
+    elif isinstance(inputs, abc.Mapping):
+        return type(inputs)({
+            k: cast_tensor_type(v, src_type, dst_type)
+            for k, v in inputs.items()
+        })
+    elif isinstance(inputs, abc.Iterable):
+        return type(inputs)(
+            cast_tensor_type(item, src_type, dst_type) for item in inputs)
+    else:
+        return inputs
+
+
+def auto_fp16(apply_to=None, out_fp32=False):
+    """Decorator to enable fp16 training automatically.
+
+    This decorator is useful when you write custom modules and want to support
+    mixed precision training. If input arguments are fp32 tensors, they will
+    be converted to fp16 automatically. Arguments other than fp32 tensors are
+    ignored. If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
+    backend, otherwise, original mmcv implementation will be adopted.
+
+    Args:
+        apply_to (Iterable, optional): The argument names to be converted.
+            `None` indicates all arguments.
+        out_fp32 (bool): Whether to convert the output back to fp32.
+
+    Example:
+
+        >>> import torch.nn as nn
+        >>> class MyModule1(nn.Module):
+        >>>
+        >>>     # Convert x and y to fp16
+        >>>     @auto_fp16()
+        >>>     def forward(self, x, y):
+        >>>         pass
+
+        >>> import torch.nn as nn
+        >>> class MyModule2(nn.Module):
+        >>>
+        >>>     # convert pred to fp16
+        >>>     @auto_fp16(apply_to=('pred', ))
+        >>>     def do_something(self, pred, others):
+        >>>         pass
+    """
+
+    def auto_fp16_wrapper(old_func):
+
+        @functools.wraps(old_func)
+        def new_func(*args, **kwargs):
+            # check if the module has set the attribute `fp16_enabled`, if not,
+            # just fallback to the original method.
+            if not isinstance(args[0], torch.nn.Module):
+                raise TypeError('@auto_fp16 can only be used to decorate the '
+                                'method of nn.Module')
+            if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
+                return old_func(*args, **kwargs)
+
+            # get the arg spec of the decorated method
+            args_info = getfullargspec(old_func)
+            # get the argument names to be cast
+            args_to_cast = args_info.args if apply_to is None else apply_to
+            # convert the args that need to be processed
+            new_args = []
+            # NOTE: default args are not taken into consideration
+            if args:
+                arg_names = args_info.args[:len(args)]
+                for i, arg_name in enumerate(arg_names):
+                    if arg_name in args_to_cast:
+                        new_args.append(
+                            cast_tensor_type(args[i], torch.float, torch.half))
+                    else:
+                        new_args.append(args[i])
+            # convert the kwargs that need to be processed
+            new_kwargs = {}
+            if kwargs:
+                for arg_name, arg_value in kwargs.items():
+                    if arg_name in args_to_cast:
+                        new_kwargs[arg_name] = cast_tensor_type(
+                            arg_value, torch.float, torch.half)
+                    else:
+                        new_kwargs[arg_name] = arg_value
+            # apply converted arguments to the decorated method
+            if (TORCH_VERSION != 'parrots' and
+                    digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
+                with autocast(enabled=True):
+                    output = old_func(*new_args, **new_kwargs)
+            else:
+                output = old_func(*new_args, **new_kwargs)
+            # cast the results back to fp32 if necessary
+            if out_fp32:
+                output = cast_tensor_type(output, torch.half, torch.float)
+            return output
+
+        return new_func
+
+    return auto_fp16_wrapper
+
+
+def force_fp32(apply_to=None, out_fp16=False):
+    """Decorator to convert input arguments to fp32 in force.
+
+    This decorator is useful when you write custom modules and want to support
+    mixed precision training. If there are some inputs that must be processed
+    in fp32 mode, then this decorator can handle it. If input arguments are
+    fp16 tensors, they will be converted to fp32 automatically. Arguments other
+    than fp16 tensors are ignored. If you are using PyTorch >= 1.6,
+    torch.cuda.amp is used as the backend, otherwise, original mmcv
+    implementation will be adopted.
+
+    Args:
+        apply_to (Iterable, optional): The argument names to be converted.
+            `None` indicates all arguments.
+        out_fp16 (bool): Whether to convert the output back to fp16.
+
+    Example:
+
+        >>> import torch.nn as nn
+        >>> class MyModule1(nn.Module):
+        >>>
+        >>>     # Convert x and y to fp32
+        >>>     @force_fp32()
+        >>>     def loss(self, x, y):
+        >>>         pass
+
+        >>> import torch.nn as nn
+        >>> class MyModule2(nn.Module):
+        >>>
+        >>>     # convert pred to fp32
+        >>>     @force_fp32(apply_to=('pred', ))
+        >>>     def post_process(self, pred, others):
+        >>>         pass
+    """
+
+    def force_fp32_wrapper(old_func):
+
+        @functools.wraps(old_func)
+        def new_func(*args, **kwargs):
+            # check if the module has set the attribute `fp16_enabled`, if not,
+            # just fallback to the original method.
+            if not isinstance(args[0], torch.nn.Module):
+                raise TypeError('@force_fp32 can only be used to decorate the '
+                                'method of nn.Module')
+            if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled):
+                return old_func(*args, **kwargs)
+            # get the arg spec of the decorated method
+            args_info = getfullargspec(old_func)
+            # get the argument names to be cast
+            args_to_cast = args_info.args if apply_to is None else apply_to
+            # convert the args that need to be processed
+            new_args = []
+            if args:
+                arg_names = args_info.args[:len(args)]
+                for i, arg_name in enumerate(arg_names):
+                    if arg_name in args_to_cast:
+                        new_args.append(
+                            cast_tensor_type(args[i], torch.half, torch.float))
+                    else:
+                        new_args.append(args[i])
+            # convert the kwargs that need to be processed
+            new_kwargs = dict()
+            if kwargs:
+                for arg_name, arg_value in kwargs.items():
+                    if arg_name in args_to_cast:
+                        new_kwargs[arg_name] = cast_tensor_type(
+                            arg_value, torch.half, torch.float)
+                    else:
+                        new_kwargs[arg_name] = arg_value
+            # apply converted arguments to the decorated method
+            if (TORCH_VERSION != 'parrots' and
+                    digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
+                with autocast(enabled=False):
+                    output = old_func(*new_args, **new_kwargs)
+            else:
+                output = old_func(*new_args, **new_kwargs)
+            # cast the results back to fp32 if necessary
+            if out_fp16:
+                output = cast_tensor_type(output, torch.float, torch.half)
+            return output
+
+        return new_func
+
+    return force_fp32_wrapper
+
+
+def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
+    warnings.warn(
+        '"mmcv.runner.fp16_utils.allreduce_grads" is deprecated, and will be '
+        'removed in v2.8. Please switch to "mmcv.runner.allreduce_grads".')
+    _allreduce_grads(params, coalesce=coalesce, bucket_size_mb=bucket_size_mb)
+
+
+def wrap_fp16_model(model):
+    """Wrap the FP32 model to FP16.
+
+    If you are using PyTorch >= 1.6, torch.cuda.amp is used as the
+    backend; otherwise, the original mmcv implementation is adopted.
+
+    For PyTorch >= 1.6, this function will
+    1. Set the `fp16_enabled` flag inside the model to True.
+
+    Otherwise, it will:
+    1. Convert the FP32 model to FP16.
+    2. Keep some necessary layers in FP32, e.g., normalization layers.
+    3. Set the `fp16_enabled` flag inside the model to True.
+
+    Args:
+        model (nn.Module): Model in FP32.
+    """
+    if (TORCH_VERSION == 'parrots'
+            or digit_version(TORCH_VERSION) < digit_version('1.6.0')):
+        # convert model to fp16
+        model.half()
+        # patch the normalization layers to make it work in fp32 mode
+        patch_norm_fp32(model)
+    # set `fp16_enabled` flag
+    for m in model.modules():
+        if hasattr(m, 'fp16_enabled'):
+            m.fp16_enabled = True
+
+
+def patch_norm_fp32(module):
+    """Recursively convert normalization layers from FP16 to FP32.
+
+    Args:
+        module (nn.Module): The module whose normalization layers are to be
+            converted to FP32.
+
+    Returns:
+        nn.Module: The converted module, with its normalization layers cast
+            to FP32.
+    """
+    if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)):
+        module.float()
+        if isinstance(module, nn.GroupNorm) or torch.__version__ < '1.3':
+            module.forward = patch_forward_method(module.forward, torch.half,
+                                                  torch.float)
+    for child in module.children():
+        patch_norm_fp32(child)
+    return module
+
+
+def patch_forward_method(func, src_type, dst_type, convert_output=True):
+    """Patch the forward method of a module.
+
+    Args:
+        func (callable): The original forward method.
+        src_type (torch.dtype): Type of input arguments to be converted from.
+        dst_type (torch.dtype): Type of input arguments to be converted to.
+        convert_output (bool): Whether to convert the output back to src_type.
+
+    Returns:
+        callable: The patched forward method.
+    """
+
+    def new_forward(*args, **kwargs):
+        output = func(*cast_tensor_type(args, src_type, dst_type),
+                      **cast_tensor_type(kwargs, src_type, dst_type))
+        if convert_output:
+            output = cast_tensor_type(output, dst_type, src_type)
+        return output
+
+    return new_forward
+
+
+class LossScaler:
+    """Class that manages loss scaling in mixed precision training which
+    supports both dynamic or static mode.
+
+    The implementation refers to
+    https://github.com/NVIDIA/apex/blob/master/apex/fp16_utils/loss_scaler.py.
+    Indirectly, by supplying ``mode='dynamic'`` for dynamic loss scaling.
+    It's important to understand how :class:`LossScaler` operates.
+    Loss scaling is designed to combat the problem of underflowing
+    gradients encountered at long times when training fp16 networks.
+    Dynamic loss scaling begins by attempting a very high loss
+    scale.  Ironically, this may result in OVERflowing gradients.
+    If overflowing gradients are encountered, :class:`FP16_Optimizer` then
+    skips the update step for this particular iteration/minibatch,
+    and :class:`LossScaler` adjusts the loss scale to a lower value.
+    If a certain number of iterations occur without overflowing gradients
+    detected,:class:`LossScaler` increases the loss scale once more.
+    In this way :class:`LossScaler` attempts to "ride the edge" of always
+    using the highest loss scale possible without incurring overflow.
+
+    Args:
+        init_scale (float): Initial loss scale value, default: 2**32.
+        scale_factor (float): Factor used when adjusting the loss scale.
+            Default: 2.
+        mode (str): Loss scaling mode, 'dynamic' or 'static'. Default: 'dynamic'.
+        scale_window (int): Number of consecutive iterations without an
+            overflow to wait before increasing the loss scale. Default: 1000.
+    """
+
+    def __init__(self,
+                 init_scale=2**32,
+                 mode='dynamic',
+                 scale_factor=2.,
+                 scale_window=1000):
+        self.cur_scale = init_scale
+        self.cur_iter = 0
+        assert mode in ('dynamic',
+                        'static'), 'mode can only be dynamic or static'
+        self.mode = mode
+        self.last_overflow_iter = -1
+        self.scale_factor = scale_factor
+        self.scale_window = scale_window
+
+    def has_overflow(self, params):
+        """Check if params contain overflow."""
+        if self.mode != 'dynamic':
+            return False
+        for p in params:
+            if p.grad is not None and LossScaler._has_inf_or_nan(p.grad.data):
+                return True
+        return False
+
+    @staticmethod
+    def _has_inf_or_nan(x):
+        """Check whether ``x`` contains inf or NaN values."""
+        try:
+            cpu_sum = float(x.float().sum())
+        except RuntimeError as instance:
+            if 'value cannot be converted' not in instance.args[0]:
+                raise
+            return True
+        else:
+            if cpu_sum == float('inf') or cpu_sum == -float('inf') \
+                    or cpu_sum != cpu_sum:
+                return True
+            return False
+
+    def update_scale(self, overflow):
+        """update the current loss scale value when overflow happens."""
+        if self.mode != 'dynamic':
+            return
+        if overflow:
+            self.cur_scale = max(self.cur_scale / self.scale_factor, 1)
+            self.last_overflow_iter = self.cur_iter
+        else:
+            if (self.cur_iter - self.last_overflow_iter) % \
+                    self.scale_window == 0:
+                self.cur_scale *= self.scale_factor
+        self.cur_iter += 1
+
+    def state_dict(self):
+        """Returns the state of the scaler as a :class:`dict`."""
+        return dict(
+            cur_scale=self.cur_scale,
+            cur_iter=self.cur_iter,
+            mode=self.mode,
+            last_overflow_iter=self.last_overflow_iter,
+            scale_factor=self.scale_factor,
+            scale_window=self.scale_window)
+
+    def load_state_dict(self, state_dict):
+        """Loads the loss_scaler state dict.
+
+        Args:
+           state_dict (dict): scaler state.
+        """
+        self.cur_scale = state_dict['cur_scale']
+        self.cur_iter = state_dict['cur_iter']
+        self.mode = state_dict['mode']
+        self.last_overflow_iter = state_dict['last_overflow_iter']
+        self.scale_factor = state_dict['scale_factor']
+        self.scale_window = state_dict['scale_window']
+
+    @property
+    def loss_scale(self):
+        return self.cur_scale
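+
+
+# ---------------------------------------------------------------------------
+# Editor's usage sketch (illustrative only, not part of the upstream mmcv
+# module): a toy head showing how ``wrap_fp16_model`` flips ``fp16_enabled``
+# so that ``@force_fp32`` starts casting, and how ``LossScaler`` reacts to
+# overflow in dynamic mode. The class name ``_ToyHead`` is hypothetical.
+# ---------------------------------------------------------------------------
+if __name__ == '__main__':
+
+    class _ToyHead(nn.Module):
+
+        def __init__(self):
+            super().__init__()
+            self.fp16_enabled = False  # flipped to True by wrap_fp16_model
+
+        @force_fp32(apply_to=('pred', ))
+        def loss(self, pred):
+            # pred arrives as fp32 even if the caller passes an fp16 tensor
+            return pred.dtype
+
+    head = _ToyHead()
+    wrap_fp16_model(head)
+    print(head.loss(torch.zeros(2, dtype=torch.half)))  # torch.float32
+
+    # Dynamic loss scaling: the scale grows after ``scale_window`` clean
+    # iterations and shrinks whenever an overflow is reported.
+    scaler = LossScaler(init_scale=2**16, mode='dynamic', scale_window=2)
+    for overflow in (False, False, True, False, False):
+        scaler.update_scale(overflow)
+        print(scaler.loss_scale)  # 65536, 131072, 65536, 65536, 131072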
diff --git a/annotator/uniformer/mmcv/runner/hooks/__init__.py b/annotator/uniformer/mmcv/runner/hooks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..915af28cefab14a14c1188ed861161080fd138a3
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/__init__.py
@@ -0,0 +1,29 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .checkpoint import CheckpointHook
+from .closure import ClosureHook
+from .ema import EMAHook
+from .evaluation import DistEvalHook, EvalHook
+from .hook import HOOKS, Hook
+from .iter_timer import IterTimerHook
+from .logger import (DvcliveLoggerHook, LoggerHook, MlflowLoggerHook,
+                     NeptuneLoggerHook, PaviLoggerHook, TensorboardLoggerHook,
+                     TextLoggerHook, WandbLoggerHook)
+from .lr_updater import LrUpdaterHook
+from .memory import EmptyCacheHook
+from .momentum_updater import MomentumUpdaterHook
+from .optimizer import (Fp16OptimizerHook, GradientCumulativeFp16OptimizerHook,
+                        GradientCumulativeOptimizerHook, OptimizerHook)
+from .profiler import ProfilerHook
+from .sampler_seed import DistSamplerSeedHook
+from .sync_buffer import SyncBuffersHook
+
+__all__ = [
+    'HOOKS', 'Hook', 'CheckpointHook', 'ClosureHook', 'LrUpdaterHook',
+    'OptimizerHook', 'Fp16OptimizerHook', 'IterTimerHook',
+    'DistSamplerSeedHook', 'EmptyCacheHook', 'LoggerHook', 'MlflowLoggerHook',
+    'PaviLoggerHook', 'TextLoggerHook', 'TensorboardLoggerHook',
+    'NeptuneLoggerHook', 'WandbLoggerHook', 'DvcliveLoggerHook',
+    'MomentumUpdaterHook', 'SyncBuffersHook', 'EMAHook', 'EvalHook',
+    'DistEvalHook', 'ProfilerHook', 'GradientCumulativeOptimizerHook',
+    'GradientCumulativeFp16OptimizerHook'
+]
diff --git a/annotator/uniformer/mmcv/runner/hooks/checkpoint.py b/annotator/uniformer/mmcv/runner/hooks/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..6af3fae43ac4b35532641a81eb13557edfc7dfba
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/checkpoint.py
@@ -0,0 +1,167 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import warnings
+
+from annotator.uniformer.mmcv.fileio import FileClient
+from ..dist_utils import allreduce_params, master_only
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class CheckpointHook(Hook):
+    """Save checkpoints periodically.
+
+    Args:
+        interval (int): The saving period. If ``by_epoch=True``, interval
+            indicates epochs, otherwise it indicates iterations.
+            Default: -1, which means "never".
+        by_epoch (bool): Saving checkpoints by epoch or by iteration.
+            Default: True.
+        save_optimizer (bool): Whether to save optimizer state_dict in the
+            checkpoint. It is usually used for resuming experiments.
+            Default: True.
+        out_dir (str, optional): The root directory to save checkpoints. If not
+            specified, ``runner.work_dir`` will be used by default. If
+            specified, the ``out_dir`` will be the concatenation of ``out_dir``
+            and the last level directory of ``runner.work_dir``.
+            `Changed in version 1.3.16.`
+        max_keep_ckpts (int, optional): The maximum checkpoints to keep.
+            In some cases we want only the latest few checkpoints and would
+            like to delete old ones to save the disk space.
+            Default: -1, which means unlimited.
+        save_last (bool, optional): Whether to force the last checkpoint to be
+            saved regardless of interval. Default: True.
+        sync_buffer (bool, optional): Whether to synchronize buffers in
+            different gpus. Default: False.
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.
+            `New in version 1.3.16.`
+
+    .. warning::
+        Before v1.3.16, the ``out_dir`` argument indicates the path where the
+        checkpoint is stored. However, since v1.3.16, ``out_dir`` indicates the
+        root directory and the final path to save checkpoint is the
+        concatenation of ``out_dir`` and the last level directory of
+        ``runner.work_dir``. Suppose the value of ``out_dir`` is "/path/of/A"
+        and the value of ``runner.work_dir`` is "/path/of/B", then the final
+        path will be "/path/of/A/B".
+    """
+
+    def __init__(self,
+                 interval=-1,
+                 by_epoch=True,
+                 save_optimizer=True,
+                 out_dir=None,
+                 max_keep_ckpts=-1,
+                 save_last=True,
+                 sync_buffer=False,
+                 file_client_args=None,
+                 **kwargs):
+        self.interval = interval
+        self.by_epoch = by_epoch
+        self.save_optimizer = save_optimizer
+        self.out_dir = out_dir
+        self.max_keep_ckpts = max_keep_ckpts
+        self.save_last = save_last
+        self.args = kwargs
+        self.sync_buffer = sync_buffer
+        self.file_client_args = file_client_args
+
+    def before_run(self, runner):
+        if not self.out_dir:
+            self.out_dir = runner.work_dir
+
+        self.file_client = FileClient.infer_client(self.file_client_args,
+                                                   self.out_dir)
+
+        # if `self.out_dir` is not equal to `runner.work_dir`, it means that
+        # `self.out_dir` is set so the final `self.out_dir` is the
+        # concatenation of `self.out_dir` and the last level directory of
+        # `runner.work_dir`
+        if self.out_dir != runner.work_dir:
+            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+            self.out_dir = self.file_client.join_path(self.out_dir, basename)
+
+        runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by '
+                            f'{self.file_client.name}.'))
+
+        # disable the create_symlink option because some file backends do not
+        # allow to create a symlink
+        if 'create_symlink' in self.args:
+            if self.args[
+                    'create_symlink'] and not self.file_client.allow_symlink:
+                self.args['create_symlink'] = False
+                warnings.warn(
+                    ('create_symlink is set as True by the user but is changed '
+                     'to be False because creating symbolic link is not '
+                     f'allowed in {self.file_client.name}'))
+        else:
+            self.args['create_symlink'] = self.file_client.allow_symlink
+
+    def after_train_epoch(self, runner):
+        if not self.by_epoch:
+            return
+
+        # save checkpoint for following cases:
+        # 1. every ``self.interval`` epochs
+        # 2. reach the last epoch of training
+        if self.every_n_epochs(
+                runner, self.interval) or (self.save_last
+                                           and self.is_last_epoch(runner)):
+            runner.logger.info(
+                f'Saving checkpoint at {runner.epoch + 1} epochs')
+            if self.sync_buffer:
+                allreduce_params(runner.model.buffers())
+            self._save_checkpoint(runner)
+
+    @master_only
+    def _save_checkpoint(self, runner):
+        """Save the current checkpoint and delete unwanted checkpoint."""
+        runner.save_checkpoint(
+            self.out_dir, save_optimizer=self.save_optimizer, **self.args)
+        if runner.meta is not None:
+            if self.by_epoch:
+                cur_ckpt_filename = self.args.get(
+                    'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
+            else:
+                cur_ckpt_filename = self.args.get(
+                    'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
+            runner.meta.setdefault('hook_msgs', dict())
+            runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
+                self.out_dir, cur_ckpt_filename)
+        # remove other checkpoints
+        if self.max_keep_ckpts > 0:
+            if self.by_epoch:
+                name = 'epoch_{}.pth'
+                current_ckpt = runner.epoch + 1
+            else:
+                name = 'iter_{}.pth'
+                current_ckpt = runner.iter + 1
+            redundant_ckpts = range(
+                current_ckpt - self.max_keep_ckpts * self.interval, 0,
+                -self.interval)
+            filename_tmpl = self.args.get('filename_tmpl', name)
+            for _step in redundant_ckpts:
+                ckpt_path = self.file_client.join_path(
+                    self.out_dir, filename_tmpl.format(_step))
+                if self.file_client.isfile(ckpt_path):
+                    self.file_client.remove(ckpt_path)
+                else:
+                    break
+
+    def after_train_iter(self, runner):
+        if self.by_epoch:
+            return
+
+        # save checkpoint for following cases:
+        # 1. every ``self.interval`` iterations
+        # 2. reach the last iteration of training
+        if self.every_n_iters(
+                runner, self.interval) or (self.save_last
+                                           and self.is_last_iter(runner)):
+            runner.logger.info(
+                f'Saving checkpoint at {runner.iter + 1} iterations')
+            if self.sync_buffer:
+                allreduce_params(runner.model.buffers())
+            self._save_checkpoint(runner)
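+
+
+# ---------------------------------------------------------------------------
+# Editor's sketch (illustrative only, not part of the upstream mmcv module):
+# typical construction of the hook. With ``runner.work_dir == '/path/of/B'``,
+# ``before_run`` would resolve the final save directory to '/path/of/A/B'
+# (see the class docstring). The paths below are hypothetical.
+# ---------------------------------------------------------------------------
+if __name__ == '__main__':
+    ckpt_hook = CheckpointHook(
+        interval=1,            # save every epoch
+        by_epoch=True,
+        max_keep_ckpts=3,      # keep only the 3 most recent checkpoints
+        out_dir='/path/of/A')
+    print(ckpt_hook.interval, ckpt_hook.max_keep_ckpts, ckpt_hook.out_dir)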
diff --git a/annotator/uniformer/mmcv/runner/hooks/closure.py b/annotator/uniformer/mmcv/runner/hooks/closure.py
new file mode 100644
index 0000000000000000000000000000000000000000..b955f81f425be4ac3e6bb3f4aac653887989e872
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/closure.py
@@ -0,0 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class ClosureHook(Hook):
+
+    def __init__(self, fn_name, fn):
+        assert hasattr(self, fn_name)
+        assert callable(fn)
+        setattr(self, fn_name, fn)
diff --git a/annotator/uniformer/mmcv/runner/hooks/ema.py b/annotator/uniformer/mmcv/runner/hooks/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..15c7e68088f019802a59e7ae41cc1fe0c7f28f96
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/ema.py
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...parallel import is_module_wrapper
+from ..hooks.hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class EMAHook(Hook):
+    r"""Exponential Moving Average Hook.
+
+    Use Exponential Moving Average on all parameters of the model during
+    training. All parameters have an EMA backup, which is updated by the
+    formula below. EMAHook takes priority over EvalHook and CheckpointHook.
+
+        .. math::
+
+            X_{\text{ema},t+1} = (1 - \text{momentum}) \times
+            X_{\text{ema},t} + \text{momentum} \times X_t
+
+    Args:
+        momentum (float): The momentum used for updating ema parameter.
+            Defaults to 0.0002.
+        interval (int): Update ema parameter every interval iteration.
+            Defaults to 1.
+        warm_up (int): During the first warm_up steps, a smaller momentum may
+            be used to update the ema parameters more slowly. Defaults to 100.
+        resume_from (str): The checkpoint path. Defaults to None.
+    """
+
+    def __init__(self,
+                 momentum=0.0002,
+                 interval=1,
+                 warm_up=100,
+                 resume_from=None):
+        assert isinstance(interval, int) and interval > 0
+        self.warm_up = warm_up
+        self.interval = interval
+        assert momentum > 0 and momentum < 1
+        self.momentum = momentum**interval
+        self.checkpoint = resume_from
+
+    def before_run(self, runner):
+        """To resume model with it's ema parameters more friendly.
+
+        Register ema parameter as ``named_buffer`` to model
+        """
+        model = runner.model
+        if is_module_wrapper(model):
+            model = model.module
+        self.param_ema_buffer = {}
+        self.model_parameters = dict(model.named_parameters(recurse=True))
+        for name, value in self.model_parameters.items():
+            # "." is not allowed in module's buffer name
+            buffer_name = f"ema_{name.replace('.', '_')}"
+            self.param_ema_buffer[name] = buffer_name
+            model.register_buffer(buffer_name, value.data.clone())
+        self.model_buffers = dict(model.named_buffers(recurse=True))
+        if self.checkpoint is not None:
+            runner.resume(self.checkpoint)
+
+    def after_train_iter(self, runner):
+        """Update ema parameter every self.interval iterations."""
+        curr_step = runner.iter
+        # We warm up the momentum considering the instability at beginning
+        momentum = min(self.momentum,
+                       (1 + curr_step) / (self.warm_up + curr_step))
+        if curr_step % self.interval != 0:
+            return
+        for name, parameter in self.model_parameters.items():
+            buffer_name = self.param_ema_buffer[name]
+            buffer_parameter = self.model_buffers[buffer_name]
+            buffer_parameter.mul_(1 - momentum).add_(momentum, parameter.data)
+
+    def after_train_epoch(self, runner):
+        """We load parameter values from ema backup to model before the
+        EvalHook."""
+        self._swap_ema_parameters()
+
+    def before_train_epoch(self, runner):
+        """We recover model's parameter from ema backup after last epoch's
+        EvalHook."""
+        self._swap_ema_parameters()
+
+    def _swap_ema_parameters(self):
+        """Swap the parameter of model with parameter in ema_buffer."""
+        for name, value in self.model_parameters.items():
+            temp = value.data.clone()
+            ema_buffer = self.model_buffers[self.param_ema_buffer[name]]
+            value.data.copy_(ema_buffer.data)
+            ema_buffer.data.copy_(temp)
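+
+
+# ---------------------------------------------------------------------------
+# Editor's sketch (illustrative only, not part of the upstream mmcv module):
+# the buffer update applied in ``after_train_iter``, written out for a single
+# scalar parameter. The numbers are arbitrary.
+# ---------------------------------------------------------------------------
+if __name__ == '__main__':
+    import torch
+
+    momentum = 0.0002
+    ema_buffer = torch.tensor([1.0])   # stands in for one registered buffer
+    param = torch.tensor([0.0])        # stands in for the live parameter
+    # X_ema' = (1 - momentum) * X_ema + momentum * X
+    ema_buffer.mul_(1 - momentum).add_(momentum * param)
+    print(ema_buffer)  # tensor([0.9998])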
diff --git a/annotator/uniformer/mmcv/runner/hooks/evaluation.py b/annotator/uniformer/mmcv/runner/hooks/evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d00999ce5665c53bded8de9e084943eee2d230d
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/evaluation.py
@@ -0,0 +1,509 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import warnings
+from math import inf
+
+import torch.distributed as dist
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.utils.data import DataLoader
+
+from annotator.uniformer.mmcv.fileio import FileClient
+from annotator.uniformer.mmcv.utils import is_seq_of
+from .hook import Hook
+from .logger import LoggerHook
+
+
+class EvalHook(Hook):
+    """Non-Distributed evaluation hook.
+
+    This hook will regularly perform evaluation in a given interval when
+    performing in non-distributed environment.
+
+    Args:
+        dataloader (DataLoader): A PyTorch dataloader, whose dataset has
+            implemented ``evaluate`` function.
+        start (int | None, optional): Evaluation starting epoch. It enables
+            evaluation before the training starts if ``start`` <= the resuming
+            epoch. If None, whether to evaluate is merely decided by
+            ``interval``. Default: None.
+        interval (int): Evaluation interval. Default: 1.
+        by_epoch (bool): Whether to perform evaluation by epoch or by
+            iteration. If set to True, it will be performed by epoch;
+            otherwise, by iteration. Default: True.
+        save_best (str, optional): If a metric is specified, it would measure
+            the best checkpoint during evaluation. The information about best
+            checkpoint would be saved in ``runner.meta['hook_msgs']`` to keep
+            best score value and best checkpoint path, which will be also
+            loaded when resume checkpoint. Options are the evaluation metrics
+            on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox
+            detection and instance segmentation. ``AR@100`` for proposal
+            recall. If ``save_best`` is ``auto``, the first key of the returned
+            ``OrderedDict`` result will be used. Default: None.
+        rule (str | None, optional): Comparison rule for best score. If set to
+            None, it will infer a reasonable rule. Keys such as 'acc', 'top',
+            etc. will be inferred by the 'greater' rule. Keys containing
+            'loss' will be inferred by the 'less' rule. Options are 'greater',
+            'less', None.
+            Default: None.
+        test_fn (callable, optional): test a model with samples from a
+            dataloader, and return the test results. If ``None``, the default
+            test function ``mmcv.engine.single_gpu_test`` will be used.
+            (default: ``None``)
+        greater_keys (List[str] | None, optional): Metric keys that will be
+            inferred by 'greater' comparison rule. If ``None``,
+            _default_greater_keys will be used. (default: ``None``)
+        less_keys (List[str] | None, optional): Metric keys that will be
+            inferred by 'less' comparison rule. If ``None``, _default_less_keys
+            will be used. (default: ``None``)
+        out_dir (str, optional): The root directory to save checkpoints. If not
+            specified, `runner.work_dir` will be used by default. If specified,
+            the `out_dir` will be the concatenation of `out_dir` and the last
+            level directory of `runner.work_dir`.
+            `New in version 1.3.16.`
+        file_client_args (dict): Arguments to instantiate a FileClient.
+            See :class:`mmcv.fileio.FileClient` for details. Default: None.
+            `New in version 1.3.16.`
+        **eval_kwargs: Evaluation arguments fed into the evaluate function of
+            the dataset.
+
+    Note:
+        If new arguments are added for EvalHook, tools/test.py and
+        tools/eval_metric.py may be affected.
+    """
+
+    # Since the keys for determining greater or less are related to the
+    # downstream tasks, downstream repos may need to overwrite the following
+    # inner variables accordingly.
+
+    rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
+    init_value_map = {'greater': -inf, 'less': inf}
+    _default_greater_keys = [
+        'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
+        'mAcc', 'aAcc'
+    ]
+    _default_less_keys = ['loss']
+
+    def __init__(self,
+                 dataloader,
+                 start=None,
+                 interval=1,
+                 by_epoch=True,
+                 save_best=None,
+                 rule=None,
+                 test_fn=None,
+                 greater_keys=None,
+                 less_keys=None,
+                 out_dir=None,
+                 file_client_args=None,
+                 **eval_kwargs):
+        if not isinstance(dataloader, DataLoader):
+            raise TypeError(f'dataloader must be a pytorch DataLoader, '
+                            f'but got {type(dataloader)}')
+
+        if interval <= 0:
+            raise ValueError(f'interval must be a positive number, '
+                             f'but got {interval}')
+
+        assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean'
+
+        if start is not None and start < 0:
+            raise ValueError(f'The evaluation start epoch {start} is smaller '
+                             f'than 0')
+
+        self.dataloader = dataloader
+        self.interval = interval
+        self.start = start
+        self.by_epoch = by_epoch
+
+        assert isinstance(save_best, str) or save_best is None, \
+            '""save_best"" should be a str or None ' \
+            f'rather than {type(save_best)}'
+        self.save_best = save_best
+        self.eval_kwargs = eval_kwargs
+        self.initial_flag = True
+
+        if test_fn is None:
+            from annotator.uniformer.mmcv.engine import single_gpu_test
+            self.test_fn = single_gpu_test
+        else:
+            self.test_fn = test_fn
+
+        if greater_keys is None:
+            self.greater_keys = self._default_greater_keys
+        else:
+            if not isinstance(greater_keys, (list, tuple)):
+                greater_keys = (greater_keys, )
+            assert is_seq_of(greater_keys, str)
+            self.greater_keys = greater_keys
+
+        if less_keys is None:
+            self.less_keys = self._default_less_keys
+        else:
+            if not isinstance(less_keys, (list, tuple)):
+                less_keys = (less_keys, )
+            assert is_seq_of(less_keys, str)
+            self.less_keys = less_keys
+
+        if self.save_best is not None:
+            self.best_ckpt_path = None
+            self._init_rule(rule, self.save_best)
+
+        self.out_dir = out_dir
+        self.file_client_args = file_client_args
+
+    def _init_rule(self, rule, key_indicator):
+        """Initialize rule, key_indicator, comparison_func, and best score.
+
+        Here is the rule to determine which rule is used for the key indicator
+        when the rule is not specified (note that the key indicator matching
+        is case-insensitive):
+        1. If the key indicator is in ``self.greater_keys``, the rule will be
+           specified as 'greater'.
+        2. Or if the key indicator is in ``self.less_keys``, the rule will be
+           specified as 'less'.
+        3. Or if any item in ``self.greater_keys`` is a substring of the key
+           indicator, the rule will be specified as 'greater'.
+        4. Or if any item in ``self.less_keys`` is a substring of the key
+           indicator, the rule will be specified as 'less'.
+
+        Args:
+            rule (str | None): Comparison rule for best score.
+            key_indicator (str | None): Key indicator to determine the
+                comparison rule.
+        """
+        if rule not in self.rule_map and rule is not None:
+            raise KeyError(f'rule must be greater, less or None, '
+                           f'but got {rule}.')
+
+        if rule is None:
+            if key_indicator != 'auto':
+                # `_lc` here means we use the lower case of keys for
+                # case-insensitive matching
+                key_indicator_lc = key_indicator.lower()
+                greater_keys = [key.lower() for key in self.greater_keys]
+                less_keys = [key.lower() for key in self.less_keys]
+
+                if key_indicator_lc in greater_keys:
+                    rule = 'greater'
+                elif key_indicator_lc in less_keys:
+                    rule = 'less'
+                elif any(key in key_indicator_lc for key in greater_keys):
+                    rule = 'greater'
+                elif any(key in key_indicator_lc for key in less_keys):
+                    rule = 'less'
+                else:
+                    raise ValueError(f'Cannot infer the rule for key '
+                                     f'{key_indicator}, thus a specific rule '
+                                     f'must be specified.')
+        self.rule = rule
+        self.key_indicator = key_indicator
+        if self.rule is not None:
+            self.compare_func = self.rule_map[self.rule]
+
+    def before_run(self, runner):
+        if not self.out_dir:
+            self.out_dir = runner.work_dir
+
+        self.file_client = FileClient.infer_client(self.file_client_args,
+                                                   self.out_dir)
+
+        # if `self.out_dir` is not equal to `runner.work_dir`, it means that
+        # `self.out_dir` is set so the final `self.out_dir` is the
+        # concatenation of `self.out_dir` and the last level directory of
+        # `runner.work_dir`
+        if self.out_dir != runner.work_dir:
+            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+            self.out_dir = self.file_client.join_path(self.out_dir, basename)
+            runner.logger.info(
+                (f'The best checkpoint will be saved to {self.out_dir} by '
+                 f'{self.file_client.name}'))
+
+        if self.save_best is not None:
+            if runner.meta is None:
+                warnings.warn('runner.meta is None. Creating an empty one.')
+                runner.meta = dict()
+            runner.meta.setdefault('hook_msgs', dict())
+            self.best_ckpt_path = runner.meta['hook_msgs'].get(
+                'best_ckpt', None)
+
+    def before_train_iter(self, runner):
+        """Evaluate the model only at the start of training by iteration."""
+        if self.by_epoch or not self.initial_flag:
+            return
+        if self.start is not None and runner.iter >= self.start:
+            self.after_train_iter(runner)
+        self.initial_flag = False
+
+    def before_train_epoch(self, runner):
+        """Evaluate the model only at the start of training by epoch."""
+        if not (self.by_epoch and self.initial_flag):
+            return
+        if self.start is not None and runner.epoch >= self.start:
+            self.after_train_epoch(runner)
+        self.initial_flag = False
+
+    def after_train_iter(self, runner):
+        """Called after every training iter to evaluate the results."""
+        if not self.by_epoch and self._should_evaluate(runner):
+            # Because the priority of EvalHook is higher than LoggerHook, the
+            # training log and the evaluating log are mixed. Therefore,
+            # we need to dump the training log and clear it before evaluating
+            # log is generated. In addition, this problem will only appear in
+            # `IterBasedRunner` whose `self.by_epoch` is False, because
+            # `EpochBasedRunner` whose `self.by_epoch` is True calls
+            # `_do_evaluate` in `after_train_epoch` stage, and at this stage
+            # the training log has been printed, so it will not cause any
+            # problem. More details at
+            # https://github.com/open-mmlab/mmsegmentation/issues/694
+            for hook in runner._hooks:
+                if isinstance(hook, LoggerHook):
+                    hook.after_train_iter(runner)
+            runner.log_buffer.clear()
+
+            self._do_evaluate(runner)
+
+    def after_train_epoch(self, runner):
+        """Called after every training epoch to evaluate the results."""
+        if self.by_epoch and self._should_evaluate(runner):
+            self._do_evaluate(runner)
+
+    def _do_evaluate(self, runner):
+        """perform evaluation and save ckpt."""
+        results = self.test_fn(runner.model, self.dataloader)
+        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+        key_score = self.evaluate(runner, results)
+        # the key_score may be `None` so it needs to skip the action to save
+        # the best checkpoint
+        if self.save_best and key_score:
+            self._save_ckpt(runner, key_score)
+
+    def _should_evaluate(self, runner):
+        """Judge whether to perform evaluation.
+
+        Here is the rule to judge whether to perform evaluation:
+        1. It will not perform evaluation during the epoch/iteration interval,
+           which is determined by ``self.interval``.
+        2. It will not perform evaluation if the start time is larger than
+           the current time.
+        3. It will not perform evaluation when the current time is larger
+           than the start time but falls within the epoch/iteration interval.
+
+        Returns:
+            bool: The flag indicating whether to perform evaluation.
+        """
+        if self.by_epoch:
+            current = runner.epoch
+            check_time = self.every_n_epochs
+        else:
+            current = runner.iter
+            check_time = self.every_n_iters
+
+        if self.start is None:
+            if not check_time(runner, self.interval):
+                # No evaluation during the interval.
+                return False
+        elif (current + 1) < self.start:
+            # No evaluation if start is larger than the current time.
+            return False
+        else:
+            # Evaluation only at epochs/iters 3, 5, 7...
+            # if start==3 and interval==2
+            if (current + 1 - self.start) % self.interval:
+                return False
+        return True
+
+    def _save_ckpt(self, runner, key_score):
+        """Save the best checkpoint.
+
+        It will compare the score according to the compare function, write
+        related information (best score, best checkpoint path) and save the
+        best checkpoint into ``work_dir``.
+        """
+        if self.by_epoch:
+            current = f'epoch_{runner.epoch + 1}'
+            cur_type, cur_time = 'epoch', runner.epoch + 1
+        else:
+            current = f'iter_{runner.iter + 1}'
+            cur_type, cur_time = 'iter', runner.iter + 1
+
+        best_score = runner.meta['hook_msgs'].get(
+            'best_score', self.init_value_map[self.rule])
+        if self.compare_func(key_score, best_score):
+            best_score = key_score
+            runner.meta['hook_msgs']['best_score'] = best_score
+
+            if self.best_ckpt_path and self.file_client.isfile(
+                    self.best_ckpt_path):
+                self.file_client.remove(self.best_ckpt_path)
+                runner.logger.info(
+                    (f'The previous best checkpoint {self.best_ckpt_path} was '
+                     'removed'))
+
+            best_ckpt_name = f'best_{self.key_indicator}_{current}.pth'
+            self.best_ckpt_path = self.file_client.join_path(
+                self.out_dir, best_ckpt_name)
+            runner.meta['hook_msgs']['best_ckpt'] = self.best_ckpt_path
+
+            runner.save_checkpoint(
+                self.out_dir, best_ckpt_name, create_symlink=False)
+            runner.logger.info(
+                f'Now best checkpoint is saved as {best_ckpt_name}.')
+            runner.logger.info(
+                f'Best {self.key_indicator} is {best_score:0.4f} '
+                f'at {cur_time} {cur_type}.')
+
+    def evaluate(self, runner, results):
+        """Evaluate the results.
+
+        Args:
+            runner (:obj:`mmcv.Runner`): The underlying training runner.
+            results (list): Output results.
+        """
+        eval_res = self.dataloader.dataset.evaluate(
+            results, logger=runner.logger, **self.eval_kwargs)
+
+        for name, val in eval_res.items():
+            runner.log_buffer.output[name] = val
+        runner.log_buffer.ready = True
+
+        if self.save_best is not None:
+            # If the performance of the model is poor, the `eval_res` may be
+            # an empty dict and it will raise an exception when
+            # `self.save_best` is not None. More details at
+            # https://github.com/open-mmlab/mmdetection/issues/6265.
+            if not eval_res:
+                warnings.warn(
+                    'Since `eval_res` is an empty dict, the behavior to save '
+                    'the best checkpoint will be skipped in this evaluation.')
+                return None
+
+            if self.key_indicator == 'auto':
+                # infer from eval_results
+                self._init_rule(self.rule, list(eval_res.keys())[0])
+            return eval_res[self.key_indicator]
+
+        return None
+
+
+class DistEvalHook(EvalHook):
+    """Distributed evaluation hook.
+
+    This hook will regularly perform evaluation in a given interval when
+    performing in distributed environment.
+
+    Args:
+        dataloader (DataLoader): A PyTorch dataloader, whose dataset has
+            implemented ``evaluate`` function.
+        start (int | None, optional): Evaluation starting epoch. It enables
+            evaluation before the training starts if ``start`` <= the resuming
+            epoch. If None, whether to evaluate is merely decided by
+            ``interval``. Default: None.
+        interval (int): Evaluation interval. Default: 1.
+        by_epoch (bool): Whether to perform evaluation by epoch or by
+            iteration. If set to True, it will be performed by epoch;
+            otherwise, by iteration. Default: True.
+        save_best (str, optional): If a metric is specified, it would measure
+            the best checkpoint during evaluation. The information about best
+            checkpoint would be saved in ``runner.meta['hook_msgs']`` to keep
+            best score value and best checkpoint path, which will be also
+            loaded when resume checkpoint. Options are the evaluation metrics
+            on the test dataset. e.g., ``bbox_mAP``, ``segm_mAP`` for bbox
+            detection and instance segmentation. ``AR@100`` for proposal
+            recall. If ``save_best`` is ``auto``, the first key of the returned
+            ``OrderedDict`` result will be used. Default: None.
+        rule (str | None, optional): Comparison rule for best score. If set to
+            None, it will infer a reasonable rule. Keys such as 'acc', 'top',
+            etc. will be inferred by the 'greater' rule. Keys containing
+            'loss' will be inferred by the 'less' rule. Options are 'greater',
+            'less', None.
+            Default: None.
+        test_fn (callable, optional): test a model with samples from a
+            dataloader in a multi-gpu manner, and return the test results. If
+            ``None``, the default test function ``mmcv.engine.multi_gpu_test``
+            will be used. (default: ``None``)
+        tmpdir (str | None): Temporary directory to save the results of all
+            processes. Default: None.
+        gpu_collect (bool): Whether to use gpu or cpu to collect results.
+            Default: False.
+        broadcast_bn_buffer (bool): Whether to broadcast the
+            buffer(running_mean and running_var) of rank 0 to other rank
+            before evaluation. Default: True.
+        out_dir (str, optional): The root directory to save checkpoints. If not
+            specified, `runner.work_dir` will be used by default. If specified,
+            the `out_dir` will be the concatenation of `out_dir` and the last
+            level directory of `runner.work_dir`.
+        file_client_args (dict): Arguments to instantiate a FileClient.
+            See :class:`mmcv.fileio.FileClient` for details. Default: None.
+        **eval_kwargs: Evaluation arguments fed into the evaluate function of
+            the dataset.
+    """
+
+    def __init__(self,
+                 dataloader,
+                 start=None,
+                 interval=1,
+                 by_epoch=True,
+                 save_best=None,
+                 rule=None,
+                 test_fn=None,
+                 greater_keys=None,
+                 less_keys=None,
+                 broadcast_bn_buffer=True,
+                 tmpdir=None,
+                 gpu_collect=False,
+                 out_dir=None,
+                 file_client_args=None,
+                 **eval_kwargs):
+
+        if test_fn is None:
+            from annotator.uniformer.mmcv.engine import multi_gpu_test
+            test_fn = multi_gpu_test
+
+        super().__init__(
+            dataloader,
+            start=start,
+            interval=interval,
+            by_epoch=by_epoch,
+            save_best=save_best,
+            rule=rule,
+            test_fn=test_fn,
+            greater_keys=greater_keys,
+            less_keys=less_keys,
+            out_dir=out_dir,
+            file_client_args=file_client_args,
+            **eval_kwargs)
+
+        self.broadcast_bn_buffer = broadcast_bn_buffer
+        self.tmpdir = tmpdir
+        self.gpu_collect = gpu_collect
+
+    def _do_evaluate(self, runner):
+        """perform evaluation and save ckpt."""
+        # Synchronization of BatchNorm's buffer (running_mean
+        # and running_var) is not supported in the DDP of pytorch,
+        # which may cause the inconsistent performance of models in
+        # different ranks, so we broadcast BatchNorm's buffers
+        # of rank 0 to other ranks to avoid this.
+        if self.broadcast_bn_buffer:
+            model = runner.model
+            for name, module in model.named_modules():
+                if isinstance(module,
+                              _BatchNorm) and module.track_running_stats:
+                    dist.broadcast(module.running_var, 0)
+                    dist.broadcast(module.running_mean, 0)
+
+        tmpdir = self.tmpdir
+        if tmpdir is None:
+            tmpdir = osp.join(runner.work_dir, '.eval_hook')
+
+        results = self.test_fn(
+            runner.model,
+            self.dataloader,
+            tmpdir=tmpdir,
+            gpu_collect=self.gpu_collect)
+        if runner.rank == 0:
+            print('\n')
+            runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+            key_score = self.evaluate(runner, results)
+            # the key_score may be `None` so it needs to skip the action to
+            # save the best checkpoint
+            if self.save_best and key_score:
+                self._save_ckpt(runner, key_score)
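+
+
+# ---------------------------------------------------------------------------
+# Editor's sketch (illustrative only, not part of the upstream mmcv module):
+# how ``save_best`` drives rule inference in ``EvalHook._init_rule``. A dummy
+# dataloader and test_fn are used so the hook can be built without a runner.
+# ---------------------------------------------------------------------------
+if __name__ == '__main__':
+    import torch
+    from torch.utils.data import TensorDataset
+
+    loader = DataLoader(TensorDataset(torch.zeros(1)))
+
+    def dummy_test_fn(model, data_loader):  # stands in for single_gpu_test
+        return []
+
+    hook = EvalHook(loader, save_best='mAP', test_fn=dummy_test_fn)
+    print(hook.rule, hook.key_indicator)   # greater mAP
+
+    hook = EvalHook(loader, save_best='loss', test_fn=dummy_test_fn)
+    print(hook.rule, hook.key_indicator)   # less loss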
diff --git a/annotator/uniformer/mmcv/runner/hooks/hook.py b/annotator/uniformer/mmcv/runner/hooks/hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8855c107727ecf85b917c890fc8b7f6359238a4
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/hook.py
@@ -0,0 +1,92 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from annotator.uniformer.mmcv.utils import Registry, is_method_overridden
+
+HOOKS = Registry('hook')
+
+
+class Hook:
+    stages = ('before_run', 'before_train_epoch', 'before_train_iter',
+              'after_train_iter', 'after_train_epoch', 'before_val_epoch',
+              'before_val_iter', 'after_val_iter', 'after_val_epoch',
+              'after_run')
+
+    def before_run(self, runner):
+        pass
+
+    def after_run(self, runner):
+        pass
+
+    def before_epoch(self, runner):
+        pass
+
+    def after_epoch(self, runner):
+        pass
+
+    def before_iter(self, runner):
+        pass
+
+    def after_iter(self, runner):
+        pass
+
+    def before_train_epoch(self, runner):
+        self.before_epoch(runner)
+
+    def before_val_epoch(self, runner):
+        self.before_epoch(runner)
+
+    def after_train_epoch(self, runner):
+        self.after_epoch(runner)
+
+    def after_val_epoch(self, runner):
+        self.after_epoch(runner)
+
+    def before_train_iter(self, runner):
+        self.before_iter(runner)
+
+    def before_val_iter(self, runner):
+        self.before_iter(runner)
+
+    def after_train_iter(self, runner):
+        self.after_iter(runner)
+
+    def after_val_iter(self, runner):
+        self.after_iter(runner)
+
+    def every_n_epochs(self, runner, n):
+        return (runner.epoch + 1) % n == 0 if n > 0 else False
+
+    def every_n_inner_iters(self, runner, n):
+        return (runner.inner_iter + 1) % n == 0 if n > 0 else False
+
+    def every_n_iters(self, runner, n):
+        return (runner.iter + 1) % n == 0 if n > 0 else False
+
+    def end_of_epoch(self, runner):
+        return runner.inner_iter + 1 == len(runner.data_loader)
+
+    def is_last_epoch(self, runner):
+        return runner.epoch + 1 == runner._max_epochs
+
+    def is_last_iter(self, runner):
+        return runner.iter + 1 == runner._max_iters
+
+    def get_triggered_stages(self):
+        trigger_stages = set()
+        for stage in Hook.stages:
+            if is_method_overridden(stage, Hook, self):
+                trigger_stages.add(stage)
+
+        # some methods will be triggered in multi stages
+        # use this dict to map method to stages.
+        method_stages_map = {
+            'before_epoch': ['before_train_epoch', 'before_val_epoch'],
+            'after_epoch': ['after_train_epoch', 'after_val_epoch'],
+            'before_iter': ['before_train_iter', 'before_val_iter'],
+            'after_iter': ['after_train_iter', 'after_val_iter'],
+        }
+
+        for method, map_stages in method_stages_map.items():
+            if is_method_overridden(method, Hook, self):
+                trigger_stages.update(map_stages)
+
+        return [stage for stage in Hook.stages if stage in trigger_stages]
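+
+
+# ---------------------------------------------------------------------------
+# Editor's sketch (illustrative only, not part of the upstream mmcv module):
+# a custom hook registered in ``HOOKS``; ``get_triggered_stages`` expands the
+# generic ``after_iter`` into its train/val stages. ``_PrintHook`` is a
+# hypothetical name.
+# ---------------------------------------------------------------------------
+if __name__ == '__main__':
+
+    @HOOKS.register_module()
+    class _PrintHook(Hook):
+
+        def after_iter(self, runner):
+            print('one iteration finished')
+
+    print(_PrintHook().get_triggered_stages())
+    # ['after_train_iter', 'after_val_iter']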
diff --git a/annotator/uniformer/mmcv/runner/hooks/iter_timer.py b/annotator/uniformer/mmcv/runner/hooks/iter_timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfd5002fe85ffc6992155ac01003878064a1d9be
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/iter_timer.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import time
+
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class IterTimerHook(Hook):
+
+    def before_epoch(self, runner):
+        self.t = time.time()
+
+    def before_iter(self, runner):
+        runner.log_buffer.update({'data_time': time.time() - self.t})
+
+    def after_iter(self, runner):
+        runner.log_buffer.update({'time': time.time() - self.t})
+        self.t = time.time()
diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/__init__.py b/annotator/uniformer/mmcv/runner/hooks/logger/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b6b345640a895368ac8a647afef6f24333d90e
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/logger/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import LoggerHook
+from .dvclive import DvcliveLoggerHook
+from .mlflow import MlflowLoggerHook
+from .neptune import NeptuneLoggerHook
+from .pavi import PaviLoggerHook
+from .tensorboard import TensorboardLoggerHook
+from .text import TextLoggerHook
+from .wandb import WandbLoggerHook
+
+__all__ = [
+    'LoggerHook', 'MlflowLoggerHook', 'PaviLoggerHook',
+    'TensorboardLoggerHook', 'TextLoggerHook', 'WandbLoggerHook',
+    'NeptuneLoggerHook', 'DvcliveLoggerHook'
+]
diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/base.py b/annotator/uniformer/mmcv/runner/hooks/logger/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..f845256729458ced821762a1b8ef881e17ff9955
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/logger/base.py
@@ -0,0 +1,166 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numbers
+from abc import ABCMeta, abstractmethod
+
+import numpy as np
+import torch
+
+from ..hook import Hook
+
+
+class LoggerHook(Hook, metaclass=ABCMeta):
+    """Base class for logger hooks.
+
+    Args:
+        interval (int): Logging interval (every k iterations).
+        ignore_last (bool): Ignore the log of last iterations in each epoch
+            if less than `interval`.
+        reset_flag (bool): Whether to clear the output buffer after logging.
+        by_epoch (bool): Whether EpochBasedRunner is used.
+    """
+
+    def __init__(self,
+                 interval=10,
+                 ignore_last=True,
+                 reset_flag=False,
+                 by_epoch=True):
+        self.interval = interval
+        self.ignore_last = ignore_last
+        self.reset_flag = reset_flag
+        self.by_epoch = by_epoch
+
+    @abstractmethod
+    def log(self, runner):
+        pass
+
+    @staticmethod
+    def is_scalar(val, include_np=True, include_torch=True):
+        """Tell the input variable is a scalar or not.
+
+        Args:
+            val: Input variable.
+            include_np (bool): Whether include 0-d np.ndarray as a scalar.
+            include_torch (bool): Whether include 0-d torch.Tensor as a scalar.
+
+        Returns:
+            bool: True or False.
+        """
+        if isinstance(val, numbers.Number):
+            return True
+        elif include_np and isinstance(val, np.ndarray) and val.ndim == 0:
+            return True
+        elif include_torch and isinstance(val, torch.Tensor) and val.numel() == 1:
+            return True
+        else:
+            return False
+
+    def get_mode(self, runner):
+        if runner.mode == 'train':
+            if 'time' in runner.log_buffer.output:
+                mode = 'train'
+            else:
+                mode = 'val'
+        elif runner.mode == 'val':
+            mode = 'val'
+        else:
+            raise ValueError(f"runner mode should be 'train' or 'val', "
+                             f'but got {runner.mode}')
+        return mode
+
+    def get_epoch(self, runner):
+        if runner.mode == 'train':
+            epoch = runner.epoch + 1
+        elif runner.mode == 'val':
+            # normal val mode
+            # runner.epoch += 1 has been done before val workflow
+            epoch = runner.epoch
+        else:
+            raise ValueError(f"runner mode should be 'train' or 'val', "
+                             f'but got {runner.mode}')
+        return epoch
+
+    def get_iter(self, runner, inner_iter=False):
+        """Get the current training iteration step."""
+        if self.by_epoch and inner_iter:
+            current_iter = runner.inner_iter + 1
+        else:
+            current_iter = runner.iter + 1
+        return current_iter
+
+    def get_lr_tags(self, runner):
+        tags = {}
+        lrs = runner.current_lr()
+        if isinstance(lrs, dict):
+            for name, value in lrs.items():
+                tags[f'learning_rate/{name}'] = value[0]
+        else:
+            tags['learning_rate'] = lrs[0]
+        return tags
+
+    def get_momentum_tags(self, runner):
+        tags = {}
+        momentums = runner.current_momentum()
+        if isinstance(momentums, dict):
+            for name, value in momentums.items():
+                tags[f'momentum/{name}'] = value[0]
+        else:
+            tags['momentum'] = momentums[0]
+        return tags
+
+    def get_loggable_tags(self,
+                          runner,
+                          allow_scalar=True,
+                          allow_text=False,
+                          add_mode=True,
+                          tags_to_skip=('time', 'data_time')):
+        tags = {}
+        for var, val in runner.log_buffer.output.items():
+            if var in tags_to_skip:
+                continue
+            if self.is_scalar(val) and not allow_scalar:
+                continue
+            if isinstance(val, str) and not allow_text:
+                continue
+            if add_mode:
+                var = f'{self.get_mode(runner)}/{var}'
+            tags[var] = val
+        tags.update(self.get_lr_tags(runner))
+        tags.update(self.get_momentum_tags(runner))
+        return tags
+
+    def before_run(self, runner):
+        for hook in runner.hooks[::-1]:
+            if isinstance(hook, LoggerHook):
+                hook.reset_flag = True
+                break
+
+    def before_epoch(self, runner):
+        runner.log_buffer.clear()  # clear logs of last epoch
+
+    def after_train_iter(self, runner):
+        if self.by_epoch and self.every_n_inner_iters(runner, self.interval):
+            runner.log_buffer.average(self.interval)
+        elif not self.by_epoch and self.every_n_iters(runner, self.interval):
+            runner.log_buffer.average(self.interval)
+        elif self.end_of_epoch(runner) and not self.ignore_last:
+            # not precise but more stable
+            runner.log_buffer.average(self.interval)
+
+        if runner.log_buffer.ready:
+            self.log(runner)
+            if self.reset_flag:
+                runner.log_buffer.clear_output()
+
+    def after_train_epoch(self, runner):
+        if runner.log_buffer.ready:
+            self.log(runner)
+            if self.reset_flag:
+                runner.log_buffer.clear_output()
+
+    def after_val_epoch(self, runner):
+        runner.log_buffer.average()
+        self.log(runner)
+        if self.reset_flag:
+            runner.log_buffer.clear_output()
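+
+
+# ---------------------------------------------------------------------------
+# Editor's sketch (illustrative only, not part of the upstream mmcv module):
+# what ``LoggerHook.is_scalar`` accepts. Concrete logger hooks rely on it when
+# filtering ``runner.log_buffer.output`` in ``get_loggable_tags``.
+# ---------------------------------------------------------------------------
+if __name__ == '__main__':
+    print(LoggerHook.is_scalar(3.14))            # True  (plain number)
+    print(LoggerHook.is_scalar(np.array(0.5)))   # True  (0-d ndarray)
+    print(LoggerHook.is_scalar(torch.ones(1)))   # True  (single-element tensor)
+    print(LoggerHook.is_scalar('top-1: 76.2'))   # False (strings are not scalars)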
diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/dvclive.py b/annotator/uniformer/mmcv/runner/hooks/logger/dvclive.py
new file mode 100644
index 0000000000000000000000000000000000000000..687cdc58c0336c92b1e4f9a410ba67ebaab2bc7a
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/logger/dvclive.py
@@ -0,0 +1,58 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class DvcliveLoggerHook(LoggerHook):
+    """Class to log metrics with dvclive.
+
+    It requires `dvclive`_ to be installed.
+
+    Args:
+        path (str): Directory where dvclive will write TSV log files.
+        interval (int): Logging interval (every k iterations).
+            Default: 10.
+        ignore_last (bool): Ignore the log of last iterations in each epoch
+            if less than `interval`.
+            Default: True.
+        reset_flag (bool): Whether to clear the output buffer after logging.
+            Default: True.
+        by_epoch (bool): Whether EpochBasedRunner is used.
+            Default: True.
+
+    .. _dvclive:
+        https://dvc.org/doc/dvclive
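+
+    A minimal config sketch (assuming the usual ``log_config`` mechanism of
+    MMCV runners; the ``path`` value below is illustrative only)::
+
+        log_config = dict(
+            interval=10,
+            hooks=[dict(type='DvcliveLoggerHook', path='work_dirs/dvclive')])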
+    """
+
+    def __init__(self,
+                 path,
+                 interval=10,
+                 ignore_last=True,
+                 reset_flag=True,
+                 by_epoch=True):
+
+        super(DvcliveLoggerHook, self).__init__(interval, ignore_last,
+                                                reset_flag, by_epoch)
+        self.path = path
+        self.import_dvclive()
+
+    def import_dvclive(self):
+        try:
+            import dvclive
+        except ImportError:
+            raise ImportError(
+                'Please run "pip install dvclive" to install dvclive')
+        self.dvclive = dvclive
+
+    @master_only
+    def before_run(self, runner):
+        self.dvclive.init(self.path)
+
+    @master_only
+    def log(self, runner):
+        tags = self.get_loggable_tags(runner)
+        if tags:
+            for k, v in tags.items():
+                self.dvclive.log(k, v, step=self.get_iter(runner))
diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/mlflow.py b/annotator/uniformer/mmcv/runner/hooks/logger/mlflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9a72592be47b534ce22573775fd5a7e8e86d72d
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/logger/mlflow.py
@@ -0,0 +1,78 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class MlflowLoggerHook(LoggerHook):
+
+    def __init__(self,
+                 exp_name=None,
+                 tags=None,
+                 log_model=True,
+                 interval=10,
+                 ignore_last=True,
+                 reset_flag=False,
+                 by_epoch=True):
+        """Class to log metrics and (optionally) a trained model to MLflow.
+
+        It requires `MLflow`_ to be installed.
+
+        Args:
+            exp_name (str, optional): Name of the experiment to be used.
+                Default None.
+                If not None, set the active experiment.
+                If experiment does not exist, an experiment with provided name
+                will be created.
+            tags (dict of str: str, optional): Tags for the current run.
+                Default None.
+                If not None, set tags for the current run.
+            log_model (bool, optional): Whether to log an MLflow artifact.
+                Default True.
+                If True, log runner.model as an MLflow artifact
+                for the current run.
+            interval (int): Logging interval (every k iterations).
+                Default: 10.
+            ignore_last (bool): Ignore the log of last iterations in each
+                epoch if less than `interval`. Default: True.
+            reset_flag (bool): Whether to clear the output buffer after
+                logging. Default: False.
+            by_epoch (bool): Whether EpochBasedRunner is used. Default: True.
+
+        .. _MLflow:
+            https://www.mlflow.org/docs/latest/index.html
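+
+        A minimal config sketch (assuming the usual ``log_config`` mechanism
+        of MMCV runners; the experiment name below is illustrative only)::
+
+            log_config = dict(
+                interval=10,
+                hooks=[dict(type='MlflowLoggerHook', exp_name='demo_exp')])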
+        """
+        super(MlflowLoggerHook, self).__init__(interval, ignore_last,
+                                               reset_flag, by_epoch)
+        self.import_mlflow()
+        self.exp_name = exp_name
+        self.tags = tags
+        self.log_model = log_model
+
+    def import_mlflow(self):
+        try:
+            import mlflow
+            import mlflow.pytorch as mlflow_pytorch
+        except ImportError:
+            raise ImportError(
+                'Please run "pip install mlflow" to install mlflow')
+        self.mlflow = mlflow
+        self.mlflow_pytorch = mlflow_pytorch
+
+    @master_only
+    def before_run(self, runner):
+        super(MlflowLoggerHook, self).before_run(runner)
+        if self.exp_name is not None:
+            self.mlflow.set_experiment(self.exp_name)
+        if self.tags is not None:
+            self.mlflow.set_tags(self.tags)
+
+    @master_only
+    def log(self, runner):
+        tags = self.get_loggable_tags(runner)
+        if tags:
+            self.mlflow.log_metrics(tags, step=self.get_iter(runner))
+
+    @master_only
+    def after_run(self, runner):
+        if self.log_model:
+            self.mlflow_pytorch.log_model(runner.model, 'models')
diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/neptune.py b/annotator/uniformer/mmcv/runner/hooks/logger/neptune.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a38772b0c93a8608f32c6357b8616e77c139dc9
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/logger/neptune.py
@@ -0,0 +1,82 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class NeptuneLoggerHook(LoggerHook):
+    """Class to log metrics to NeptuneAI.
+
+    It requires `neptune-client` to be installed.
+
+    Args:
+        init_kwargs (dict): a dict containing the initialization keys below:
+            - project (str): Name of a project in a form of
+                namespace/project_name. If None, the value of
+                NEPTUNE_PROJECT environment variable will be taken.
+            - api_token (str): User’s API token.
+                If None, the value of NEPTUNE_API_TOKEN environment
+                variable will be taken. Note: It is strongly recommended
+                to use NEPTUNE_API_TOKEN environment variable rather than
+                placing your API token in plain text in your source code.
+            - name (str, optional, default is 'Untitled'): Editable name of
+                the run. Name is displayed in the run's Details and in
+                Runs table as a column.
+            Check https://docs.neptune.ai/api-reference/neptune#init for
+                more init arguments.
+        interval (int): Logging interval (every k iterations). Default: 10.
+        ignore_last (bool): Ignore the log of last iterations in each epoch
+            if less than `interval`. Default: True.
+        reset_flag (bool): Whether to clear the output buffer after logging.
+            Default: True.
+        with_step (bool): If True, the step is passed to Neptune together
+            with each logged value; otherwise ``global_step`` is added to the
+            logged tags. Default: True.
+        by_epoch (bool): Whether EpochBasedRunner is used. Default: True.
+
+    .. _NeptuneAI:
+        https://docs.neptune.ai/you-should-know/logging-metadata
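+
+    A minimal config sketch (assuming the usual ``log_config`` mechanism of
+    MMCV runners; the project name below is illustrative only)::
+
+        log_config = dict(
+            interval=10,
+            hooks=[
+                dict(
+                    type='NeptuneLoggerHook',
+                    init_kwargs=dict(project='my-workspace/my-project'))
+            ])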
+    """
+
+    def __init__(self,
+                 init_kwargs=None,
+                 interval=10,
+                 ignore_last=True,
+                 reset_flag=True,
+                 with_step=True,
+                 by_epoch=True):
+
+        super(NeptuneLoggerHook, self).__init__(interval, ignore_last,
+                                                reset_flag, by_epoch)
+        self.import_neptune()
+        self.init_kwargs = init_kwargs
+        self.with_step = with_step
+
+    def import_neptune(self):
+        try:
+            import neptune.new as neptune
+        except ImportError:
+            raise ImportError(
+                'Please run "pip install neptune-client" to install neptune')
+        self.neptune = neptune
+        self.run = None
+
+    @master_only
+    def before_run(self, runner):
+        if self.init_kwargs:
+            self.run = self.neptune.init(**self.init_kwargs)
+        else:
+            self.run = self.neptune.init()
+
+    @master_only
+    def log(self, runner):
+        tags = self.get_loggable_tags(runner)
+        if tags:
+            for tag_name, tag_value in tags.items():
+                if self.with_step:
+                    self.run[tag_name].log(
+                        tag_value, step=self.get_iter(runner))
+                else:
+                    tags['global_step'] = self.get_iter(runner)
+                    self.run[tag_name].log(tags)
+
+    @master_only
+    def after_run(self, runner):
+        self.run.stop()
diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/pavi.py b/annotator/uniformer/mmcv/runner/hooks/logger/pavi.py
new file mode 100644
index 0000000000000000000000000000000000000000..1dcf146d8163aff1363e9764999b0a74d674a595
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/logger/pavi.py
@@ -0,0 +1,117 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import json
+import os
+import os.path as osp
+
+import torch
+import yaml
+
+import annotator.uniformer.mmcv as mmcv
+from ....parallel.utils import is_module_wrapper
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class PaviLoggerHook(LoggerHook):
+
+    def __init__(self,
+                 init_kwargs=None,
+                 add_graph=False,
+                 add_last_ckpt=False,
+                 interval=10,
+                 ignore_last=True,
+                 reset_flag=False,
+                 by_epoch=True,
+                 img_key='img_info'):
+        super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag,
+                                             by_epoch)
+        self.init_kwargs = init_kwargs
+        self.add_graph = add_graph
+        self.add_last_ckpt = add_last_ckpt
+        self.img_key = img_key
+
+    @master_only
+    def before_run(self, runner):
+        super(PaviLoggerHook, self).before_run(runner)
+        try:
+            from pavi import SummaryWriter
+        except ImportError:
+            raise ImportError('Please run "pip install pavi" to install pavi.')
+
+        self.run_name = runner.work_dir.split('/')[-1]
+
+        if not self.init_kwargs:
+            self.init_kwargs = dict()
+        self.init_kwargs['name'] = self.run_name
+        self.init_kwargs['model'] = runner._model_name
+        if runner.meta is not None:
+            if 'config_dict' in runner.meta:
+                config_dict = runner.meta['config_dict']
+                assert isinstance(
+                    config_dict,
+                    dict), ('meta["config_dict"] has to be of a dict, '
+                            f'but got {type(config_dict)}')
+            elif 'config_file' in runner.meta:
+                config_file = runner.meta['config_file']
+                config_dict = dict(mmcv.Config.fromfile(config_file))
+            else:
+                config_dict = None
+            if config_dict is not None:
+                # 'max_.*iter' is parsed in pavi sdk as the maximum iterations
+                #  to properly set up the progress bar.
+                config_dict = config_dict.copy()
+                config_dict.setdefault('max_iter', runner.max_iters)
+                # non-serializable values are first converted in
+                # mmcv.dump to json
+                config_dict = json.loads(
+                    mmcv.dump(config_dict, file_format='json'))
+                session_text = yaml.dump(config_dict)
+                self.init_kwargs['session_text'] = session_text
+        self.writer = SummaryWriter(**self.init_kwargs)
+
+    def get_step(self, runner):
+        """Get the total training step/epoch."""
+        if self.get_mode(runner) == 'val' and self.by_epoch:
+            return self.get_epoch(runner)
+        else:
+            return self.get_iter(runner)
+
+    @master_only
+    def log(self, runner):
+        tags = self.get_loggable_tags(runner, add_mode=False)
+        if tags:
+            self.writer.add_scalars(
+                self.get_mode(runner), tags, self.get_step(runner))
+
+    @master_only
+    def after_run(self, runner):
+        if self.add_last_ckpt:
+            ckpt_path = osp.join(runner.work_dir, 'latest.pth')
+            if osp.islink(ckpt_path):
+                ckpt_path = osp.join(runner.work_dir, os.readlink(ckpt_path))
+
+            if osp.isfile(ckpt_path):
+                # runner.epoch += 1 has been done before `after_run`.
+                iteration = runner.epoch if self.by_epoch else runner.iter
+                return self.writer.add_snapshot_file(
+                    tag=self.run_name,
+                    snapshot_file_path=ckpt_path,
+                    iteration=iteration)
+
+        # flush the buffer and send a task ending signal to Pavi
+        self.writer.close()
+
+    @master_only
+    def before_epoch(self, runner):
+        if runner.epoch == 0 and self.add_graph:
+            if is_module_wrapper(runner.model):
+                _model = runner.model.module
+            else:
+                _model = runner.model
+            device = next(_model.parameters()).device
+            data = next(iter(runner.data_loader))
+            image = data[self.img_key][0:1].to(device)
+            with torch.no_grad():
+                self.writer.add_graph(_model, image)
diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/tensorboard.py b/annotator/uniformer/mmcv/runner/hooks/logger/tensorboard.py
new file mode 100644
index 0000000000000000000000000000000000000000..4dd5011dc08def6c09eef86d3ce5b124c9fc5372
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/logger/tensorboard.py
@@ -0,0 +1,57 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+
+from annotator.uniformer.mmcv.utils import TORCH_VERSION, digit_version
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class TensorboardLoggerHook(LoggerHook):
+
+    def __init__(self,
+                 log_dir=None,
+                 interval=10,
+                 ignore_last=True,
+                 reset_flag=False,
+                 by_epoch=True):
+        super(TensorboardLoggerHook, self).__init__(interval, ignore_last,
+                                                    reset_flag, by_epoch)
+        self.log_dir = log_dir
+
+    @master_only
+    def before_run(self, runner):
+        super(TensorboardLoggerHook, self).before_run(runner)
+        if (TORCH_VERSION == 'parrots'
+                or digit_version(TORCH_VERSION) < digit_version('1.1')):
+            try:
+                from tensorboardX import SummaryWriter
+            except ImportError:
+                raise ImportError('Please install tensorboardX to use '
+                                  'TensorboardLoggerHook.')
+        else:
+            try:
+                from torch.utils.tensorboard import SummaryWriter
+            except ImportError:
+                raise ImportError(
+                    'Please run "pip install future tensorboard" to install '
+                    'the dependencies to use torch.utils.tensorboard '
+                    '(applicable to PyTorch 1.1 or higher)')
+
+        if self.log_dir is None:
+            self.log_dir = osp.join(runner.work_dir, 'tf_logs')
+        self.writer = SummaryWriter(self.log_dir)
+
+    @master_only
+    def log(self, runner):
+        tags = self.get_loggable_tags(runner, allow_text=True)
+        for tag, val in tags.items():
+            if isinstance(val, str):
+                self.writer.add_text(tag, val, self.get_iter(runner))
+            else:
+                self.writer.add_scalar(tag, val, self.get_iter(runner))
+
+    @master_only
+    def after_run(self, runner):
+        self.writer.close()
diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/text.py b/annotator/uniformer/mmcv/runner/hooks/logger/text.py
new file mode 100644
index 0000000000000000000000000000000000000000..87b1a3eca9595a130121526f8b4c29915387ab35
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/logger/text.py
@@ -0,0 +1,256 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import datetime
+import os
+import os.path as osp
+from collections import OrderedDict
+
+import torch
+import torch.distributed as dist
+
+import annotator.uniformer.mmcv as mmcv
+from annotator.uniformer.mmcv.fileio.file_client import FileClient
+from annotator.uniformer.mmcv.utils import is_tuple_of, scandir
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class TextLoggerHook(LoggerHook):
+    """Logger hook in text.
+
+    In this logger hook, the information will be printed on terminal and
+    saved in json file.
+
+    Args:
+        by_epoch (bool, optional): Whether EpochBasedRunner is used.
+            Default: True.
+        interval (int, optional): Logging interval (every k iterations).
+            Default: 10.
+        ignore_last (bool, optional): Ignore the log of last iterations in each
+            epoch if less than :attr:`interval`. Default: True.
+        reset_flag (bool, optional): Whether to clear the output buffer after
+            logging. Default: False.
+        interval_exp_name (int, optional): Logging interval for experiment
+            name. This feature is to help users conveniently get the experiment
+            information from screen or log file. Default: 1000.
+        out_dir (str, optional): Logs are saved in ``runner.work_dir`` by
+            default. If ``out_dir`` is specified, logs will be copied to a new
+            directory which is the concatenation of ``out_dir`` and the last
+            level directory of ``runner.work_dir``. Default: None.
+            `New in version 1.3.16.`
+        out_suffix (str or tuple[str], optional): Those filenames ending with
+            ``out_suffix`` will be copied to ``out_dir``.
+            Default: ('.log.json', '.log', '.py').
+            `New in version 1.3.16.`
+        keep_local (bool, optional): Whether to keep local log when
+            :attr:`out_dir` is specified. If False, the local log will be
+            removed. Default: True.
+            `New in version 1.3.16.`
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmcv.fileio.FileClient` for details.
+            Default: None.
+            `New in version 1.3.16.`
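+
+    A minimal config sketch (assuming the usual ``log_config`` mechanism of
+    MMCV runners; the ``out_dir`` value below is illustrative only)::
+
+        log_config = dict(
+            interval=50,
+            hooks=[dict(type='TextLoggerHook', out_dir='/tmp/exp_logs')])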
+    """
+
+    def __init__(self,
+                 by_epoch=True,
+                 interval=10,
+                 ignore_last=True,
+                 reset_flag=False,
+                 interval_exp_name=1000,
+                 out_dir=None,
+                 out_suffix=('.log.json', '.log', '.py'),
+                 keep_local=True,
+                 file_client_args=None):
+        super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag,
+                                             by_epoch)
+        self.by_epoch = by_epoch
+        self.time_sec_tot = 0
+        self.interval_exp_name = interval_exp_name
+
+        if out_dir is None and file_client_args is not None:
+            raise ValueError(
+                'file_client_args should be "None" when `out_dir` is not '
+                'specified.')
+        self.out_dir = out_dir
+
+        if not (out_dir is None or isinstance(out_dir, str)
+                or is_tuple_of(out_dir, str)):
+            raise TypeError('out_dir should be "None" or string or tuple of '
+                            f'string, but got {out_dir}')
+        self.out_suffix = out_suffix
+
+        self.keep_local = keep_local
+        self.file_client_args = file_client_args
+        if self.out_dir is not None:
+            self.file_client = FileClient.infer_client(file_client_args,
+                                                       self.out_dir)
+
+    def before_run(self, runner):
+        super(TextLoggerHook, self).before_run(runner)
+
+        if self.out_dir is not None:
+            self.file_client = FileClient.infer_client(self.file_client_args,
+                                                       self.out_dir)
+            # The final `self.out_dir` is the concatenation of `self.out_dir`
+            # and the last level directory of `runner.work_dir`
+            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+            self.out_dir = self.file_client.join_path(self.out_dir, basename)
+            runner.logger.info(
+                (f'Text logs will be saved to {self.out_dir} by '
+                 f'{self.file_client.name} after the training process.'))
+
+        self.start_iter = runner.iter
+        self.json_log_path = osp.join(runner.work_dir,
+                                      f'{runner.timestamp}.log.json')
+        if runner.meta is not None:
+            self._dump_log(runner.meta, runner)
+
+    def _get_max_memory(self, runner):
+        device = getattr(runner.model, 'output_device', None)
+        mem = torch.cuda.max_memory_allocated(device=device)
+        mem_mb = torch.tensor([mem / (1024 * 1024)],
+                              dtype=torch.int,
+                              device=device)
+        if runner.world_size > 1:
+            dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
+        return mem_mb.item()
+
+    def _log_info(self, log_dict, runner):
+        # print exp name for users to distinguish experiments
+        # at every ``interval_exp_name`` iterations and the end of each epoch
+        if runner.meta is not None and 'exp_name' in runner.meta:
+            if (self.every_n_iters(runner, self.interval_exp_name)) or (
+                    self.by_epoch and self.end_of_epoch(runner)):
+                exp_info = f'Exp name: {runner.meta["exp_name"]}'
+                runner.logger.info(exp_info)
+
+        if log_dict['mode'] == 'train':
+            if isinstance(log_dict['lr'], dict):
+                lr_str = []
+                for k, val in log_dict['lr'].items():
+                    lr_str.append(f'lr_{k}: {val:.3e}')
+                lr_str = ' '.join(lr_str)
+            else:
+                lr_str = f'lr: {log_dict["lr"]:.3e}'
+
+            # by epoch: Epoch [4][100/1000]
+            # by iter:  Iter [100/100000]
+            if self.by_epoch:
+                log_str = f'Epoch [{log_dict["epoch"]}]' \
+                          f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t'
+            else:
+                log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}]\t'
+            log_str += f'{lr_str}, '
+
+            if 'time' in log_dict.keys():
+                self.time_sec_tot += (log_dict['time'] * self.interval)
+                time_sec_avg = self.time_sec_tot / (
+                    runner.iter - self.start_iter + 1)
+                eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
+                eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
+                log_str += f'eta: {eta_str}, '
+                log_str += f'time: {log_dict["time"]:.3f}, ' \
+                           f'data_time: {log_dict["data_time"]:.3f}, '
+                # statistic memory
+                if torch.cuda.is_available():
+                    log_str += f'memory: {log_dict["memory"]}, '
+        else:
+            # val/test time
+            # here 1000 is the length of the val dataloader
+            # by epoch: Epoch[val] [4][1000]
+            # by iter: Iter[val] [1000]
+            if self.by_epoch:
+                log_str = f'Epoch({log_dict["mode"]}) ' \
+                    f'[{log_dict["epoch"]}][{log_dict["iter"]}]\t'
+            else:
+                log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t'
+
+        log_items = []
+        for name, val in log_dict.items():
+            # TODO: resolve this hack
+            # these items have been in log_str
+            if name in [
+                    'mode', 'Epoch', 'iter', 'lr', 'time', 'data_time',
+                    'memory', 'epoch'
+            ]:
+                continue
+            if isinstance(val, float):
+                val = f'{val:.4f}'
+            log_items.append(f'{name}: {val}')
+        log_str += ', '.join(log_items)
+
+        runner.logger.info(log_str)
+
+    def _dump_log(self, log_dict, runner):
+        # dump log in json format
+        json_log = OrderedDict()
+        for k, v in log_dict.items():
+            json_log[k] = self._round_float(v)
+        # only append log at last line
+        if runner.rank == 0:
+            with open(self.json_log_path, 'a+') as f:
+                mmcv.dump(json_log, f, file_format='json')
+                f.write('\n')
+
+    def _round_float(self, items):
+        if isinstance(items, list):
+            return [self._round_float(item) for item in items]
+        elif isinstance(items, float):
+            return round(items, 5)
+        else:
+            return items
+
+    def log(self, runner):
+        if 'eval_iter_num' in runner.log_buffer.output:
+            # this does not modify runner.iter and is independent of by_epoch
+            cur_iter = runner.log_buffer.output.pop('eval_iter_num')
+        else:
+            cur_iter = self.get_iter(runner, inner_iter=True)
+
+        log_dict = OrderedDict(
+            mode=self.get_mode(runner),
+            epoch=self.get_epoch(runner),
+            iter=cur_iter)
+
+        # only record lr of the first param group
+        cur_lr = runner.current_lr()
+        if isinstance(cur_lr, list):
+            log_dict['lr'] = cur_lr[0]
+        else:
+            assert isinstance(cur_lr, dict)
+            log_dict['lr'] = {}
+            for k, lr_ in cur_lr.items():
+                assert isinstance(lr_, list)
+                log_dict['lr'].update({k: lr_[0]})
+
+        if 'time' in runner.log_buffer.output:
+            # statistic memory
+            if torch.cuda.is_available():
+                log_dict['memory'] = self._get_max_memory(runner)
+
+        log_dict = dict(log_dict, **runner.log_buffer.output)
+
+        self._log_info(log_dict, runner)
+        self._dump_log(log_dict, runner)
+        return log_dict
+
+    def after_run(self, runner):
+        # copy or upload logs to self.out_dir
+        if self.out_dir is not None:
+            for filename in scandir(runner.work_dir, self.out_suffix, True):
+                local_filepath = osp.join(runner.work_dir, filename)
+                out_filepath = self.file_client.join_path(
+                    self.out_dir, filename)
+                with open(local_filepath, 'r') as f:
+                    self.file_client.put_text(f.read(), out_filepath)
+
+                runner.logger.info(
+                    (f'The file {local_filepath} has been uploaded to '
+                     f'{out_filepath}.'))
+
+                if not self.keep_local:
+                    os.remove(local_filepath)
+                    runner.logger.info(
+                        (f'{local_filepath} was removed because '
+                         '`keep_local` is set to False.'))
diff --git a/annotator/uniformer/mmcv/runner/hooks/logger/wandb.py b/annotator/uniformer/mmcv/runner/hooks/logger/wandb.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f6808462eb79ab2b04806a5d9f0d3dd079b5ea9
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/logger/wandb.py
@@ -0,0 +1,56 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ...dist_utils import master_only
+from ..hook import HOOKS
+from .base import LoggerHook
+
+
+@HOOKS.register_module()
+class WandbLoggerHook(LoggerHook):
+
+    def __init__(self,
+                 init_kwargs=None,
+                 interval=10,
+                 ignore_last=True,
+                 reset_flag=False,
+                 commit=True,
+                 by_epoch=True,
+                 with_step=True):
+        super(WandbLoggerHook, self).__init__(interval, ignore_last,
+                                              reset_flag, by_epoch)
+        self.import_wandb()
+        self.init_kwargs = init_kwargs
+        self.commit = commit
+        self.with_step = with_step
+
+    def import_wandb(self):
+        try:
+            import wandb
+        except ImportError:
+            raise ImportError(
+                'Please run "pip install wandb" to install wandb')
+        self.wandb = wandb
+
+    @master_only
+    def before_run(self, runner):
+        super(WandbLoggerHook, self).before_run(runner)
+        if self.wandb is None:
+            self.import_wandb()
+        if self.init_kwargs:
+            self.wandb.init(**self.init_kwargs)
+        else:
+            self.wandb.init()
+
+    @master_only
+    def log(self, runner):
+        tags = self.get_loggable_tags(runner)
+        if tags:
+            if self.with_step:
+                self.wandb.log(
+                    tags, step=self.get_iter(runner), commit=self.commit)
+            else:
+                tags['global_step'] = self.get_iter(runner)
+                self.wandb.log(tags, commit=self.commit)
+
+    @master_only
+    def after_run(self, runner):
+        self.wandb.join()
diff --git a/annotator/uniformer/mmcv/runner/hooks/lr_updater.py b/annotator/uniformer/mmcv/runner/hooks/lr_updater.py
new file mode 100644
index 0000000000000000000000000000000000000000..6365908ddf6070086de2ffc0afada46ed2f32256
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/lr_updater.py
@@ -0,0 +1,670 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numbers
+from math import cos, pi
+
+import annotator.uniformer.mmcv as mmcv
+from .hook import HOOKS, Hook
+
+
+class LrUpdaterHook(Hook):
+    """LR Scheduler in MMCV.
+
+    Args:
+        by_epoch (bool): LR changes epoch by epoch
+        warmup (str): Type of warmup used. It can be None (use no warmup),
+            'constant', 'linear' or 'exp'.
+        warmup_iters (int): The number of iterations or epochs that warmup
+            lasts
+        warmup_ratio (float): LR used at the beginning of warmup equals to
+            warmup_ratio * initial_lr
+        warmup_by_epoch (bool): When warmup_by_epoch == True, warmup_iters
+            means the number of epochs that warmup lasts, otherwise means the
+            number of iterations that warmup lasts.
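+
+    For example, with ``warmup='linear'``, ``warmup_iters=500`` and
+    ``warmup_ratio=0.1``, the LR used at iteration ``i < 500`` is
+    ``regular_lr * (1 - (1 - i / 500) * (1 - 0.1))``, i.e. it ramps linearly
+    from ``0.1 * regular_lr`` up to the regular LR over the first 500
+    iterations.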
+    """
+
+    def __init__(self,
+                 by_epoch=True,
+                 warmup=None,
+                 warmup_iters=0,
+                 warmup_ratio=0.1,
+                 warmup_by_epoch=False):
+        # validate the "warmup" argument
+        if warmup is not None:
+            if warmup not in ['constant', 'linear', 'exp']:
+                raise ValueError(
+                    f'"{warmup}" is not a supported type for warming up, valid'
+                    ' types are "constant" and "linear"')
+        if warmup is not None:
+            assert warmup_iters > 0, \
+                '"warmup_iters" must be a positive integer'
+            assert 0 < warmup_ratio <= 1.0, \
+                '"warmup_ratio" must be in range (0,1]'
+
+        self.by_epoch = by_epoch
+        self.warmup = warmup
+        self.warmup_iters = warmup_iters
+        self.warmup_ratio = warmup_ratio
+        self.warmup_by_epoch = warmup_by_epoch
+
+        if self.warmup_by_epoch:
+            self.warmup_epochs = self.warmup_iters
+            self.warmup_iters = None
+        else:
+            self.warmup_epochs = None
+
+        self.base_lr = []  # initial lr for all param groups
+        self.regular_lr = []  # expected lr if no warming up is performed
+
+    def _set_lr(self, runner, lr_groups):
+        if isinstance(runner.optimizer, dict):
+            for k, optim in runner.optimizer.items():
+                for param_group, lr in zip(optim.param_groups, lr_groups[k]):
+                    param_group['lr'] = lr
+        else:
+            for param_group, lr in zip(runner.optimizer.param_groups,
+                                       lr_groups):
+                param_group['lr'] = lr
+
+    def get_lr(self, runner, base_lr):
+        raise NotImplementedError
+
+    def get_regular_lr(self, runner):
+        if isinstance(runner.optimizer, dict):
+            lr_groups = {}
+            for k in runner.optimizer.keys():
+                _lr_group = [
+                    self.get_lr(runner, _base_lr)
+                    for _base_lr in self.base_lr[k]
+                ]
+                lr_groups.update({k: _lr_group})
+
+            return lr_groups
+        else:
+            return [self.get_lr(runner, _base_lr) for _base_lr in self.base_lr]
+
+    def get_warmup_lr(self, cur_iters):
+
+        def _get_warmup_lr(cur_iters, regular_lr):
+            if self.warmup == 'constant':
+                warmup_lr = [_lr * self.warmup_ratio for _lr in regular_lr]
+            elif self.warmup == 'linear':
+                k = (1 - cur_iters / self.warmup_iters) * (1 -
+                                                           self.warmup_ratio)
+                warmup_lr = [_lr * (1 - k) for _lr in regular_lr]
+            elif self.warmup == 'exp':
+                k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
+                warmup_lr = [_lr * k for _lr in regular_lr]
+            return warmup_lr
+
+        if isinstance(self.regular_lr, dict):
+            lr_groups = {}
+            for key, regular_lr in self.regular_lr.items():
+                lr_groups[key] = _get_warmup_lr(cur_iters, regular_lr)
+            return lr_groups
+        else:
+            return _get_warmup_lr(cur_iters, self.regular_lr)
+
+    def before_run(self, runner):
+        # NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,
+        # it will be set according to the optimizer params
+        if isinstance(runner.optimizer, dict):
+            self.base_lr = {}
+            for k, optim in runner.optimizer.items():
+                for group in optim.param_groups:
+                    group.setdefault('initial_lr', group['lr'])
+                _base_lr = [
+                    group['initial_lr'] for group in optim.param_groups
+                ]
+                self.base_lr.update({k: _base_lr})
+        else:
+            for group in runner.optimizer.param_groups:
+                group.setdefault('initial_lr', group['lr'])
+            self.base_lr = [
+                group['initial_lr'] for group in runner.optimizer.param_groups
+            ]
+
+    def before_train_epoch(self, runner):
+        if self.warmup_iters is None:
+            epoch_len = len(runner.data_loader)
+            self.warmup_iters = self.warmup_epochs * epoch_len
+
+        if not self.by_epoch:
+            return
+
+        self.regular_lr = self.get_regular_lr(runner)
+        self._set_lr(runner, self.regular_lr)
+
+    def before_train_iter(self, runner):
+        cur_iter = runner.iter
+        if not self.by_epoch:
+            self.regular_lr = self.get_regular_lr(runner)
+            if self.warmup is None or cur_iter >= self.warmup_iters:
+                self._set_lr(runner, self.regular_lr)
+            else:
+                warmup_lr = self.get_warmup_lr(cur_iter)
+                self._set_lr(runner, warmup_lr)
+        elif self.by_epoch:
+            if self.warmup is None or cur_iter > self.warmup_iters:
+                return
+            elif cur_iter == self.warmup_iters:
+                self._set_lr(runner, self.regular_lr)
+            else:
+                warmup_lr = self.get_warmup_lr(cur_iter)
+                self._set_lr(runner, warmup_lr)
+
+
+@HOOKS.register_module()
+class FixedLrUpdaterHook(LrUpdaterHook):
+
+    def __init__(self, **kwargs):
+        super(FixedLrUpdaterHook, self).__init__(**kwargs)
+
+    def get_lr(self, runner, base_lr):
+        return base_lr
+
+
+@HOOKS.register_module()
+class StepLrUpdaterHook(LrUpdaterHook):
+    """Step LR scheduler with min_lr clipping.
+
+    Args:
+        step (int | list[int]): Step to decay the LR. If an int value is given,
+            regard it as the decay interval. If a list is given, decay LR at
+            these steps.
+        gamma (float, optional): Decay LR ratio. Default: 0.1.
+        min_lr (float, optional): Minimum LR value to keep. If LR after decay
+            is lower than `min_lr`, it will be clipped to this value. If None
+            is given, we don't perform lr clipping. Default: None.
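+
+    For example, ``step=[8, 11]`` with ``gamma=0.1`` keeps the base LR for
+    the first 8 epochs (or iterations), multiplies it by 0.1 at step 8 and by
+    another 0.1 at step 11.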
+    """
+
+    def __init__(self, step, gamma=0.1, min_lr=None, **kwargs):
+        if isinstance(step, list):
+            assert mmcv.is_list_of(step, int)
+            assert all([s > 0 for s in step])
+        elif isinstance(step, int):
+            assert step > 0
+        else:
+            raise TypeError('"step" must be a list or integer')
+        self.step = step
+        self.gamma = gamma
+        self.min_lr = min_lr
+        super(StepLrUpdaterHook, self).__init__(**kwargs)
+
+    def get_lr(self, runner, base_lr):
+        progress = runner.epoch if self.by_epoch else runner.iter
+
+        # calculate exponential term
+        if isinstance(self.step, int):
+            exp = progress // self.step
+        else:
+            exp = len(self.step)
+            for i, s in enumerate(self.step):
+                if progress < s:
+                    exp = i
+                    break
+
+        lr = base_lr * (self.gamma**exp)
+        if self.min_lr is not None:
+            # clip to a minimum value
+            lr = max(lr, self.min_lr)
+        return lr
+
+
+@HOOKS.register_module()
+class ExpLrUpdaterHook(LrUpdaterHook):
+
+    def __init__(self, gamma, **kwargs):
+        self.gamma = gamma
+        super(ExpLrUpdaterHook, self).__init__(**kwargs)
+
+    def get_lr(self, runner, base_lr):
+        progress = runner.epoch if self.by_epoch else runner.iter
+        return base_lr * self.gamma**progress
+
+
+@HOOKS.register_module()
+class PolyLrUpdaterHook(LrUpdaterHook):
+
+    def __init__(self, power=1., min_lr=0., **kwargs):
+        self.power = power
+        self.min_lr = min_lr
+        super(PolyLrUpdaterHook, self).__init__(**kwargs)
+
+    def get_lr(self, runner, base_lr):
+        if self.by_epoch:
+            progress = runner.epoch
+            max_progress = runner.max_epochs
+        else:
+            progress = runner.iter
+            max_progress = runner.max_iters
+        coeff = (1 - progress / max_progress)**self.power
+        return (base_lr - self.min_lr) * coeff + self.min_lr
+
+
+@HOOKS.register_module()
+class InvLrUpdaterHook(LrUpdaterHook):
+
+    def __init__(self, gamma, power=1., **kwargs):
+        self.gamma = gamma
+        self.power = power
+        super(InvLrUpdaterHook, self).__init__(**kwargs)
+
+    def get_lr(self, runner, base_lr):
+        progress = runner.epoch if self.by_epoch else runner.iter
+        return base_lr * (1 + self.gamma * progress)**(-self.power)
+
+
+@HOOKS.register_module()
+class CosineAnnealingLrUpdaterHook(LrUpdaterHook):
+
+    def __init__(self, min_lr=None, min_lr_ratio=None, **kwargs):
+        assert (min_lr is None) ^ (min_lr_ratio is None)
+        self.min_lr = min_lr
+        self.min_lr_ratio = min_lr_ratio
+        super(CosineAnnealingLrUpdaterHook, self).__init__(**kwargs)
+
+    def get_lr(self, runner, base_lr):
+        if self.by_epoch:
+            progress = runner.epoch
+            max_progress = runner.max_epochs
+        else:
+            progress = runner.iter
+            max_progress = runner.max_iters
+
+        if self.min_lr_ratio is not None:
+            target_lr = base_lr * self.min_lr_ratio
+        else:
+            target_lr = self.min_lr
+        return annealing_cos(base_lr, target_lr, progress / max_progress)
+
+
+@HOOKS.register_module()
+class FlatCosineAnnealingLrUpdaterHook(LrUpdaterHook):
+    """Flat + Cosine lr schedule.
+
+    Modified from https://github.com/fastai/fastai/blob/master/fastai/callback/schedule.py#L128 # noqa: E501
+
+    Args:
+        start_percent (float): When to start annealing the learning rate
+            after the percentage of the total training steps.
+            The value should be in range [0, 1).
+            Default: 0.75
+        min_lr (float, optional): The minimum lr. Default: None.
+        min_lr_ratio (float, optional): The ratio of minimum lr to the base lr.
+            Either `min_lr` or `min_lr_ratio` should be specified.
+            Default: None.
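+
+    A minimal config sketch (assuming the usual ``lr_config`` / ``policy``
+    convention of MMCV runners)::
+
+        lr_config = dict(
+            policy='FlatCosineAnnealing',
+            start_percent=0.75,
+            min_lr_ratio=0.01)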
+    """
+
+    def __init__(self,
+                 start_percent=0.75,
+                 min_lr=None,
+                 min_lr_ratio=None,
+                 **kwargs):
+        assert (min_lr is None) ^ (min_lr_ratio is None)
+        if start_percent < 0 or start_percent > 1 or not isinstance(
+                start_percent, float):
+            raise ValueError(
+                'expected float between 0 and 1 start_percent, but '
+                f'got {start_percent}')
+        self.start_percent = start_percent
+        self.min_lr = min_lr
+        self.min_lr_ratio = min_lr_ratio
+        super(FlatCosineAnnealingLrUpdaterHook, self).__init__(**kwargs)
+
+    def get_lr(self, runner, base_lr):
+        if self.by_epoch:
+            start = round(runner.max_epochs * self.start_percent)
+            progress = runner.epoch - start
+            max_progress = runner.max_epochs - start
+        else:
+            start = round(runner.max_iters * self.start_percent)
+            progress = runner.iter - start
+            max_progress = runner.max_iters - start
+
+        if self.min_lr_ratio is not None:
+            target_lr = base_lr * self.min_lr_ratio
+        else:
+            target_lr = self.min_lr
+
+        if progress < 0:
+            return base_lr
+        else:
+            return annealing_cos(base_lr, target_lr, progress / max_progress)
+
+
+@HOOKS.register_module()
+class CosineRestartLrUpdaterHook(LrUpdaterHook):
+    """Cosine annealing with restarts learning rate scheme.
+
+    Args:
+        periods (list[int]): Periods for each cosine annealing cycle.
+        restart_weights (list[float], optional): Restart weights at each
+            restart iteration. Default: [1].
+        min_lr (float, optional): The minimum lr. Default: None.
+        min_lr_ratio (float, optional): The ratio of minimum lr to the base lr.
+            Either `min_lr` or `min_lr_ratio` should be specified.
+            Default: None.
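+
+    For example, ``periods=[10, 10, 10]`` with
+    ``restart_weights=[1, 0.5, 0.25]`` restarts the cosine schedule every 10
+    epochs (or iterations), with the peak LR of each cycle scaled towards the
+    target LR by the corresponding restart weight.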
+    """
+
+    def __init__(self,
+                 periods,
+                 restart_weights=[1],
+                 min_lr=None,
+                 min_lr_ratio=None,
+                 **kwargs):
+        assert (min_lr is None) ^ (min_lr_ratio is None)
+        self.periods = periods
+        self.min_lr = min_lr
+        self.min_lr_ratio = min_lr_ratio
+        self.restart_weights = restart_weights
+        assert (len(self.periods) == len(self.restart_weights)
+                ), 'periods and restart_weights should have the same length.'
+        super(CosineRestartLrUpdaterHook, self).__init__(**kwargs)
+
+        self.cumulative_periods = [
+            sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
+        ]
+
+    def get_lr(self, runner, base_lr):
+        if self.by_epoch:
+            progress = runner.epoch
+        else:
+            progress = runner.iter
+
+        if self.min_lr_ratio is not None:
+            target_lr = base_lr * self.min_lr_ratio
+        else:
+            target_lr = self.min_lr
+
+        idx = get_position_from_periods(progress, self.cumulative_periods)
+        current_weight = self.restart_weights[idx]
+        nearest_restart = 0 if idx == 0 else self.cumulative_periods[idx - 1]
+        current_periods = self.periods[idx]
+
+        alpha = min((progress - nearest_restart) / current_periods, 1)
+        return annealing_cos(base_lr, target_lr, alpha, current_weight)
+
+
+def get_position_from_periods(iteration, cumulative_periods):
+    """Get the position from a period list.
+
+    It returns the index of the first cumulative period that is strictly
+    greater than the given iteration.
+    For example, the cumulative_periods = [100, 200, 300, 400],
+    if iteration == 50, return 0;
+    if iteration == 210, return 2;
+    if iteration == 300, return 3.
+
+    Args:
+        iteration (int): Current iteration.
+        cumulative_periods (list[int]): Cumulative period list.
+
+    Returns:
+        int: The index of the first cumulative period that is greater than
+            ``iteration``.
+    """
+    for i, period in enumerate(cumulative_periods):
+        if iteration < period:
+            return i
+    raise ValueError(f'Current iteration {iteration} exceeds '
+                     f'cumulative_periods {cumulative_periods}')
+
+
+@HOOKS.register_module()
+class CyclicLrUpdaterHook(LrUpdaterHook):
+    """Cyclic LR Scheduler.
+
+    Implement the cyclical learning rate policy (CLR) described in
+    https://arxiv.org/pdf/1506.01186.pdf
+
+    Different from the original paper, we use cosine annealing rather than
+    triangular policy inside a cycle. This improves the performance in the
+    3D detection area.
+
+    Args:
+        by_epoch (bool): Whether to update LR by epoch.
+        target_ratio (tuple[float]): Relative ratio of the highest LR and the
+            lowest LR to the initial LR.
+        cyclic_times (int): Number of cycles during training
+        step_ratio_up (float): The ratio of the increasing process of LR in
+            the total cycle.
+        anneal_strategy (str): {'cos', 'linear'}
+            Specifies the annealing strategy: 'cos' for cosine annealing,
+            'linear' for linear annealing. Default: 'cos'.
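+
+    A minimal config sketch (assuming the usual ``lr_config`` / ``policy``
+    convention of MMCV runners)::
+
+        lr_config = dict(
+            policy='cyclic',
+            target_ratio=(10, 1e-4),
+            cyclic_times=1,
+            step_ratio_up=0.4)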
+    """
+
+    def __init__(self,
+                 by_epoch=False,
+                 target_ratio=(10, 1e-4),
+                 cyclic_times=1,
+                 step_ratio_up=0.4,
+                 anneal_strategy='cos',
+                 **kwargs):
+        if isinstance(target_ratio, float):
+            target_ratio = (target_ratio, target_ratio / 1e5)
+        elif isinstance(target_ratio, tuple):
+            target_ratio = (target_ratio[0], target_ratio[0] / 1e5) \
+                if len(target_ratio) == 1 else target_ratio
+        else:
+            raise ValueError('target_ratio should be either float '
+                             f'or tuple, got {type(target_ratio)}')
+
+        assert len(target_ratio) == 2, \
+            '"target_ratio" must be list or tuple of two floats'
+        assert 0 <= step_ratio_up < 1.0, \
+            '"step_ratio_up" must be in range [0,1)'
+
+        self.target_ratio = target_ratio
+        self.cyclic_times = cyclic_times
+        self.step_ratio_up = step_ratio_up
+        self.lr_phases = []  # init lr_phases
+        # validate anneal_strategy
+        if anneal_strategy not in ['cos', 'linear']:
+            raise ValueError('anneal_strategy must be one of "cos" or '
+                             f'"linear", instead got {anneal_strategy}')
+        elif anneal_strategy == 'cos':
+            self.anneal_func = annealing_cos
+        elif anneal_strategy == 'linear':
+            self.anneal_func = annealing_linear
+
+        assert not by_epoch, \
+            'currently only support "by_epoch" = False'
+        super(CyclicLrUpdaterHook, self).__init__(by_epoch, **kwargs)
+
+    def before_run(self, runner):
+        super(CyclicLrUpdaterHook, self).before_run(runner)
+        # initiate lr_phases
+        # total lr_phases are separated as up and down
+        max_iter_per_phase = runner.max_iters // self.cyclic_times
+        iter_up_phase = int(self.step_ratio_up * max_iter_per_phase)
+        self.lr_phases.append(
+            [0, iter_up_phase, max_iter_per_phase, 1, self.target_ratio[0]])
+        self.lr_phases.append([
+            iter_up_phase, max_iter_per_phase, max_iter_per_phase,
+            self.target_ratio[0], self.target_ratio[1]
+        ])
+
+    def get_lr(self, runner, base_lr):
+        curr_iter = runner.iter
+        for (start_iter, end_iter, max_iter_per_phase, start_ratio,
+             end_ratio) in self.lr_phases:
+            curr_iter %= max_iter_per_phase
+            if start_iter <= curr_iter < end_iter:
+                progress = curr_iter - start_iter
+                return self.anneal_func(base_lr * start_ratio,
+                                        base_lr * end_ratio,
+                                        progress / (end_iter - start_iter))
+
+
+@HOOKS.register_module()
+class OneCycleLrUpdaterHook(LrUpdaterHook):
+    """One Cycle LR Scheduler.
+
+    The 1cycle learning rate policy changes the learning rate after every
+    batch. The one cycle learning rate policy is described in
+    https://arxiv.org/pdf/1708.07120.pdf
+
+    Args:
+        max_lr (float or list): Upper learning rate boundaries in the cycle
+            for each parameter group.
+        total_steps (int, optional): The total number of steps in the cycle.
+            Note that if a value is not provided here, it will be the
+            ``max_iters`` of the runner. Default: None.
+        pct_start (float): The percentage of the cycle (in number of steps)
+            spent increasing the learning rate.
+            Default: 0.3
+        anneal_strategy (str): {'cos', 'linear'}
+            Specifies the annealing strategy: 'cos' for cosine annealing,
+            'linear' for linear annealing.
+            Default: 'cos'
+        div_factor (float): Determines the initial learning rate via
+            initial_lr = max_lr/div_factor
+            Default: 25
+        final_div_factor (float): Determines the minimum learning rate via
+            min_lr = initial_lr/final_div_factor
+            Default: 1e4
+        three_phase (bool): If three_phase is True, use a third phase of the
+            schedule to annihilate the learning rate according to
+            final_div_factor instead of modifying the second phase (the first
+            two phases will be symmetrical about the step indicated by
+            pct_start).
+            Default: False
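+
+    A minimal config sketch (assuming the usual ``lr_config`` / ``policy``
+    convention of MMCV runners)::
+
+        lr_config = dict(
+            policy='OneCycle',
+            max_lr=0.01,
+            pct_start=0.3,
+            anneal_strategy='cos',
+            div_factor=25,
+            final_div_factor=1e4)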
+    """
+
+    def __init__(self,
+                 max_lr,
+                 total_steps=None,
+                 pct_start=0.3,
+                 anneal_strategy='cos',
+                 div_factor=25,
+                 final_div_factor=1e4,
+                 three_phase=False,
+                 **kwargs):
+        # validate by_epoch, currently only support by_epoch = False
+        if 'by_epoch' not in kwargs:
+            kwargs['by_epoch'] = False
+        else:
+            assert not kwargs['by_epoch'], \
+                'currently only support "by_epoch" = False'
+        if not isinstance(max_lr, (numbers.Number, list, dict)):
+            raise ValueError('the type of max_lr must be number, list or '
+                             f'dict, but got {type(max_lr)}')
+        self._max_lr = max_lr
+        if total_steps is not None:
+            if not isinstance(total_steps, int):
+                raise ValueError('the type of total_steps must be int, but '
+                                 f'got {type(total_steps)}')
+            self.total_steps = total_steps
+        # validate pct_start
+        if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
+            raise ValueError('expected float between 0 and 1 pct_start, but '
+                             f'got {pct_start}')
+        self.pct_start = pct_start
+        # validate anneal_strategy
+        if anneal_strategy not in ['cos', 'linear']:
+            raise ValueError('anneal_strategy must be one of "cos" or '
+                             f'"linear", instead got {anneal_strategy}')
+        elif anneal_strategy == 'cos':
+            self.anneal_func = annealing_cos
+        elif anneal_strategy == 'linear':
+            self.anneal_func = annealing_linear
+        self.div_factor = div_factor
+        self.final_div_factor = final_div_factor
+        self.three_phase = three_phase
+        self.lr_phases = []  # init lr_phases
+        super(OneCycleLrUpdaterHook, self).__init__(**kwargs)
+
+    def before_run(self, runner):
+        if hasattr(self, 'total_steps'):
+            total_steps = self.total_steps
+        else:
+            total_steps = runner.max_iters
+        if total_steps < runner.max_iters:
+            raise ValueError(
+                'The total steps must be greater than or equal to max '
+                f'iterations {runner.max_iters} of runner, but total steps '
+                f'is {total_steps}.')
+
+        if isinstance(runner.optimizer, dict):
+            self.base_lr = {}
+            for k, optim in runner.optimizer.items():
+                _max_lr = format_param(k, optim, self._max_lr)
+                self.base_lr[k] = [lr / self.div_factor for lr in _max_lr]
+                for group, lr in zip(optim.param_groups, self.base_lr[k]):
+                    group.setdefault('initial_lr', lr)
+        else:
+            k = type(runner.optimizer).__name__
+            _max_lr = format_param(k, runner.optimizer, self._max_lr)
+            self.base_lr = [lr / self.div_factor for lr in _max_lr]
+            for group, lr in zip(runner.optimizer.param_groups, self.base_lr):
+                group.setdefault('initial_lr', lr)
+
+        if self.three_phase:
+            self.lr_phases.append(
+                [float(self.pct_start * total_steps) - 1, 1, self.div_factor])
+            self.lr_phases.append([
+                float(2 * self.pct_start * total_steps) - 2, self.div_factor, 1
+            ])
+            self.lr_phases.append(
+                [total_steps - 1, 1, 1 / self.final_div_factor])
+        else:
+            self.lr_phases.append(
+                [float(self.pct_start * total_steps) - 1, 1, self.div_factor])
+            self.lr_phases.append(
+                [total_steps - 1, self.div_factor, 1 / self.final_div_factor])
+
+    def get_lr(self, runner, base_lr):
+        curr_iter = runner.iter
+        start_iter = 0
+        for i, (end_iter, start_lr, end_lr) in enumerate(self.lr_phases):
+            if curr_iter <= end_iter:
+                pct = (curr_iter - start_iter) / (end_iter - start_iter)
+                lr = self.anneal_func(base_lr * start_lr, base_lr * end_lr,
+                                      pct)
+                break
+            start_iter = end_iter
+        return lr
+
+
+def annealing_cos(start, end, factor, weight=1):
+    """Calculate annealing cos learning rate.
+
+    Cosine anneal from `weight * start + (1 - weight) * end` to `end` as
+    percentage goes from 0.0 to 1.0.
+
+    Args:
+        start (float): The starting learning rate of the cosine annealing.
+        end (float): The ending learning rate of the cosine annealing.
+        factor (float): The coefficient of `pi` when calculating the current
+            percentage. Range from 0.0 to 1.0.
+        weight (float, optional): The combination factor of `start` and `end`
+            when calculating the actual starting learning rate. Default to 1.
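+
+    Example (a quick sanity check of the formula)::
+
+        >>> annealing_cos(1.0, 0.0, 0.0)   # start of the schedule
+        1.0
+        >>> annealing_cos(1.0, 0.0, 0.5)   # halfway through
+        0.5
+        >>> annealing_cos(1.0, 0.0, 1.0)   # fully annealed
+        0.0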
+    """
+    cos_out = cos(pi * factor) + 1
+    return end + 0.5 * weight * (start - end) * cos_out
+
+
+def annealing_linear(start, end, factor):
+    """Calculate annealing linear learning rate.
+
+    Linear anneal from `start` to `end` as percentage goes from 0.0 to 1.0.
+
+    Args:
+        start (float): The starting learning rate of the linear annealing.
+        end (float): The ending learning rate of the linear annealing.
+        factor (float): The current annealing percentage, in the range from
+            0.0 to 1.0.
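+
+    Example::
+
+        >>> annealing_linear(1.0, 0.0, 0.25)
+        0.75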
+    """
+    return start + (end - start) * factor
+
+
+def format_param(name, optim, param):
+    if isinstance(param, numbers.Number):
+        return [param] * len(optim.param_groups)
+    elif isinstance(param, (list, tuple)):  # multi param groups
+        if len(param) != len(optim.param_groups):
+            raise ValueError(f'expected {len(optim.param_groups)} '
+                             f'values for {name}, got {len(param)}')
+        return param
+    else:  # multi optimizers
+        if name not in param:
+            raise KeyError(f'{name} is not found in {param.keys()}')
+        return param[name]
diff --git a/annotator/uniformer/mmcv/runner/hooks/memory.py b/annotator/uniformer/mmcv/runner/hooks/memory.py
new file mode 100644
index 0000000000000000000000000000000000000000..70cf9a838fb314e3bd3c07aadbc00921a81e83ed
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/memory.py
@@ -0,0 +1,25 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class EmptyCacheHook(Hook):
+
+    def __init__(self, before_epoch=False, after_epoch=True, after_iter=False):
+        self._before_epoch = before_epoch
+        self._after_epoch = after_epoch
+        self._after_iter = after_iter
+
+    def after_iter(self, runner):
+        if self._after_iter:
+            torch.cuda.empty_cache()
+
+    def before_epoch(self, runner):
+        if self._before_epoch:
+            torch.cuda.empty_cache()
+
+    def after_epoch(self, runner):
+        if self._after_epoch:
+            torch.cuda.empty_cache()
diff --git a/annotator/uniformer/mmcv/runner/hooks/momentum_updater.py b/annotator/uniformer/mmcv/runner/hooks/momentum_updater.py
new file mode 100644
index 0000000000000000000000000000000000000000..60437756ceedf06055ec349df69a25465738d3f0
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/momentum_updater.py
@@ -0,0 +1,493 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import annotator.uniformer.mmcv as mmcv
+from .hook import HOOKS, Hook
+from .lr_updater import annealing_cos, annealing_linear, format_param
+
+
+class MomentumUpdaterHook(Hook):
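+    """Base hook for momentum schedulers.
+
+    Subclasses implement :meth:`get_momentum`; this base class handles the
+    optional warmup ('constant', 'linear' or 'exp') and writes the computed
+    value back into ``param_group['momentum']``, or into the first element of
+    ``param_group['betas']`` for Adam-style optimizers.
+    """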
+
+    def __init__(self,
+                 by_epoch=True,
+                 warmup=None,
+                 warmup_iters=0,
+                 warmup_ratio=0.9):
+        # validate the "warmup" argument
+        if warmup is not None:
+            if warmup not in ['constant', 'linear', 'exp']:
+                raise ValueError(
+                    f'"{warmup}" is not a supported type for warming up, valid'
+                    ' types are "constant" and "linear"')
+        if warmup is not None:
+            assert warmup_iters > 0, \
+                '"warmup_iters" must be a positive integer'
+            assert 0 < warmup_ratio <= 1.0, \
+                '"warmup_momentum" must be in range (0,1]'
+
+        self.by_epoch = by_epoch
+        self.warmup = warmup
+        self.warmup_iters = warmup_iters
+        self.warmup_ratio = warmup_ratio
+
+        self.base_momentum = []  # initial momentum for all param groups
+        self.regular_momentum = [
+        ]  # expected momentum if no warming up is performed
+
+    def _set_momentum(self, runner, momentum_groups):
+        if isinstance(runner.optimizer, dict):
+            for k, optim in runner.optimizer.items():
+                for param_group, mom in zip(optim.param_groups,
+                                            momentum_groups[k]):
+                    if 'momentum' in param_group.keys():
+                        param_group['momentum'] = mom
+                    elif 'betas' in param_group.keys():
+                        param_group['betas'] = (mom, param_group['betas'][1])
+        else:
+            for param_group, mom in zip(runner.optimizer.param_groups,
+                                        momentum_groups):
+                if 'momentum' in param_group.keys():
+                    param_group['momentum'] = mom
+                elif 'betas' in param_group.keys():
+                    param_group['betas'] = (mom, param_group['betas'][1])
+
+    def get_momentum(self, runner, base_momentum):
+        raise NotImplementedError
+
+    def get_regular_momentum(self, runner):
+        if isinstance(runner.optimizer, dict):
+            momentum_groups = {}
+            for k in runner.optimizer.keys():
+                _momentum_group = [
+                    self.get_momentum(runner, _base_momentum)
+                    for _base_momentum in self.base_momentum[k]
+                ]
+                momentum_groups.update({k: _momentum_group})
+            return momentum_groups
+        else:
+            return [
+                self.get_momentum(runner, _base_momentum)
+                for _base_momentum in self.base_momentum
+            ]
+
+    def get_warmup_momentum(self, cur_iters):
+
+        def _get_warmup_momentum(cur_iters, regular_momentum):
+            if self.warmup == 'constant':
+                warmup_momentum = [
+                    _momentum / self.warmup_ratio
+                    for _momentum in regular_momentum
+                ]
+            elif self.warmup == 'linear':
+                k = (1 - cur_iters / self.warmup_iters) * (1 -
+                                                           self.warmup_ratio)
+                warmup_momentum = [
+                    _momentum / (1 - k) for _momentum in regular_momentum
+                ]
+            elif self.warmup == 'exp':
+                k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
+                warmup_momentum = [
+                    _momentum / k for _momentum in regular_momentum
+                ]
+            return warmup_momentum
+
+        if isinstance(self.regular_momentum, dict):
+            momentum_groups = {}
+            for key, regular_momentum in self.regular_momentum.items():
+                momentum_groups[key] = _get_warmup_momentum(
+                    cur_iters, regular_momentum)
+            return momentum_groups
+        else:
+            return _get_warmup_momentum(cur_iters, self.regular_momentum)
+
+    def before_run(self, runner):
+        # NOTE: when resuming from a checkpoint,
+        # if 'initial_momentum' is not saved,
+        # it will be set according to the optimizer params
+        if isinstance(runner.optimizer, dict):
+            self.base_momentum = {}
+            for k, optim in runner.optimizer.items():
+                for group in optim.param_groups:
+                    if 'momentum' in group.keys():
+                        group.setdefault('initial_momentum', group['momentum'])
+                    else:
+                        group.setdefault('initial_momentum', group['betas'][0])
+                _base_momentum = [
+                    group['initial_momentum'] for group in optim.param_groups
+                ]
+                self.base_momentum.update({k: _base_momentum})
+        else:
+            for group in runner.optimizer.param_groups:
+                if 'momentum' in group.keys():
+                    group.setdefault('initial_momentum', group['momentum'])
+                else:
+                    group.setdefault('initial_momentum', group['betas'][0])
+            self.base_momentum = [
+                group['initial_momentum']
+                for group in runner.optimizer.param_groups
+            ]
+
+    def before_train_epoch(self, runner):
+        if not self.by_epoch:
+            return
+        self.regular_momentum = self.get_regular_momentum(runner)
+        self._set_momentum(runner, self.regular_momentum)
+
+    def before_train_iter(self, runner):
+        cur_iter = runner.iter
+        if not self.by_epoch:
+            self.regular_momentum = self.get_regular_momentum(runner)
+            if self.warmup is None or cur_iter >= self.warmup_iters:
+                self._set_momentum(runner, self.regular_momentum)
+            else:
+                warmup_momentum = self.get_warmup_momentum(cur_iter)
+                self._set_momentum(runner, warmup_momentum)
+        elif self.by_epoch:
+            if self.warmup is None or cur_iter > self.warmup_iters:
+                return
+            elif cur_iter == self.warmup_iters:
+                self._set_momentum(runner, self.regular_momentum)
+            else:
+                warmup_momentum = self.get_warmup_momentum(cur_iter)
+                self._set_momentum(runner, warmup_momentum)
+
+
+@HOOKS.register_module()
+class StepMomentumUpdaterHook(MomentumUpdaterHook):
+    """Step momentum scheduler with min value clipping.
+
+    Args:
+        step (int | list[int]): Step to decay the momentum. If an int value is
+            given, regard it as the decay interval. If a list is given, decay
+            momentum at these steps.
+        gamma (float, optional): Decay momentum ratio. Default: 0.5.
+        min_momentum (float, optional): Minimum momentum value to keep. If
+            momentum after decay is lower than this value, it will be clipped
+            accordingly. If None is given, no momentum clipping is performed.
+            Default: None.
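+
+    Example (illustrative values only):
+        >>> # decay momentum by 0.5 at epochs 8 and 11, but never below 0.8
+        >>> hook = StepMomentumUpdaterHook(step=[8, 11], gamma=0.5,
+        ...                                min_momentum=0.8)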
+    """
+
+    def __init__(self, step, gamma=0.5, min_momentum=None, **kwargs):
+        if isinstance(step, list):
+            assert mmcv.is_list_of(step, int)
+            assert all([s > 0 for s in step])
+        elif isinstance(step, int):
+            assert step > 0
+        else:
+            raise TypeError('"step" must be a list or integer')
+        self.step = step
+        self.gamma = gamma
+        self.min_momentum = min_momentum
+        super(StepMomentumUpdaterHook, self).__init__(**kwargs)
+
+    def get_momentum(self, runner, base_momentum):
+        progress = runner.epoch if self.by_epoch else runner.iter
+
+        # calculate exponential term
+        if isinstance(self.step, int):
+            exp = progress // self.step
+        else:
+            exp = len(self.step)
+            for i, s in enumerate(self.step):
+                if progress < s:
+                    exp = i
+                    break
+
+        momentum = base_momentum * (self.gamma**exp)
+        if self.min_momentum is not None:
+            # clip to a minimum value
+            momentum = max(momentum, self.min_momentum)
+        return momentum
+
+
+@HOOKS.register_module()
+class CosineAnnealingMomentumUpdaterHook(MomentumUpdaterHook):
+
+    def __init__(self, min_momentum=None, min_momentum_ratio=None, **kwargs):
+        assert (min_momentum is None) ^ (min_momentum_ratio is None)
+        self.min_momentum = min_momentum
+        self.min_momentum_ratio = min_momentum_ratio
+        super(CosineAnnealingMomentumUpdaterHook, self).__init__(**kwargs)
+
+    def get_momentum(self, runner, base_momentum):
+        if self.by_epoch:
+            progress = runner.epoch
+            max_progress = runner.max_epochs
+        else:
+            progress = runner.iter
+            max_progress = runner.max_iters
+        if self.min_momentum_ratio is not None:
+            target_momentum = base_momentum * self.min_momentum_ratio
+        else:
+            target_momentum = self.min_momentum
+        return annealing_cos(base_momentum, target_momentum,
+                             progress / max_progress)
+
+
+@HOOKS.register_module()
+class CyclicMomentumUpdaterHook(MomentumUpdaterHook):
+    """Cyclic momentum Scheduler.
+
+    Implement the cyclical momentum scheduler policy described in
+    https://arxiv.org/pdf/1708.07120.pdf
+
+    This momentum scheduler is usually used together with CyclicLrUpdaterHook
+    to improve performance in the 3D detection area.
+
+    Attributes:
+        target_ratio (tuple[float]): Relative ratio of the lowest momentum and
+            the highest momentum to the initial momentum.
+        cyclic_times (int): Number of cycles during training
+        step_ratio_up (float): The ratio of the increasing process of momentum
+            in the total cycle.
+        by_epoch (bool): Whether to update momentum by epoch.
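+
+    Example (illustrative sketch, assuming the optimizer was created with an
+    initial momentum of 0.95):
+        >>> hook = CyclicMomentumUpdaterHook(target_ratio=(0.85 / 0.95, 1),
+        ...                                  cyclic_times=1, step_ratio_up=0.4)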
+    """
+
+    def __init__(self,
+                 by_epoch=False,
+                 target_ratio=(0.85 / 0.95, 1),
+                 cyclic_times=1,
+                 step_ratio_up=0.4,
+                 **kwargs):
+        if isinstance(target_ratio, float):
+            target_ratio = (target_ratio, target_ratio / 1e5)
+        elif isinstance(target_ratio, tuple):
+            target_ratio = (target_ratio[0], target_ratio[0] / 1e5) \
+                if len(target_ratio) == 1 else target_ratio
+        else:
+            raise ValueError('target_ratio should be either float '
+                             f'or tuple, got {type(target_ratio)}')
+
+        assert len(target_ratio) == 2, \
+            '"target_ratio" must be list or tuple of two floats'
+        assert 0 <= step_ratio_up < 1.0, \
+            '"step_ratio_up" must be in range [0,1)'
+
+        self.target_ratio = target_ratio
+        self.cyclic_times = cyclic_times
+        self.step_ratio_up = step_ratio_up
+        self.momentum_phases = []  # init momentum_phases
+        # currently only support by_epoch=False
+        assert not by_epoch, \
+            'currently only support "by_epoch" = False'
+        super(CyclicMomentumUpdaterHook, self).__init__(by_epoch, **kwargs)
+
+    def before_run(self, runner):
+        super(CyclicMomentumUpdaterHook, self).before_run(runner)
+        # initiate momentum_phases
+        # total momentum_phases are separated as up and down
+        max_iter_per_phase = runner.max_iters // self.cyclic_times
+        iter_up_phase = int(self.step_ratio_up * max_iter_per_phase)
+        self.momentum_phases.append(
+            [0, iter_up_phase, max_iter_per_phase, 1, self.target_ratio[0]])
+        self.momentum_phases.append([
+            iter_up_phase, max_iter_per_phase, max_iter_per_phase,
+            self.target_ratio[0], self.target_ratio[1]
+        ])
+
+    def get_momentum(self, runner, base_momentum):
+        curr_iter = runner.iter
+        for (start_iter, end_iter, max_iter_per_phase, start_ratio,
+             end_ratio) in self.momentum_phases:
+            curr_iter %= max_iter_per_phase
+            if start_iter <= curr_iter < end_iter:
+                progress = curr_iter - start_iter
+                return annealing_cos(base_momentum * start_ratio,
+                                     base_momentum * end_ratio,
+                                     progress / (end_iter - start_iter))
+
+
+@HOOKS.register_module()
+class OneCycleMomentumUpdaterHook(MomentumUpdaterHook):
+    """OneCycle momentum Scheduler.
+
+    This momentum scheduler is usually used together with OneCycleLrUpdaterHook
+    to improve performance.
+
+    Args:
+        base_momentum (float or list): Lower momentum boundaries in the cycle
+            for each parameter group. Note that momentum is cycled inversely
+            to learning rate; at the peak of a cycle, momentum is
+            'base_momentum' and learning rate is 'max_lr'.
+            Default: 0.85
+        max_momentum (float or list): Upper momentum boundaries in the cycle
+            for each parameter group. Functionally,
+            it defines the cycle amplitude (max_momentum - base_momentum).
+            Note that momentum is cycled inversely
+            to learning rate; at the start of a cycle, momentum is
+            'max_momentum' and learning rate is 'base_lr'
+            Default: 0.95
+        pct_start (float): The percentage of the cycle (in number of steps)
+            spent increasing the learning rate.
+            Default: 0.3
+        anneal_strategy (str): {'cos', 'linear'}
+            Specifies the annealing strategy: 'cos' for cosine annealing,
+            'linear' for linear annealing.
+            Default: 'cos'
+        three_phase (bool): If three_phase is True, use a third phase of the
+            schedule to annihilate the learning rate according to
+            final_div_factor instead of modifying the second phase (the first
+            two phases will be symmetrical about the step indicated by
+            pct_start).
+            Default: False
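+
+    Example (illustrative sketch; typically registered together with an
+    ``OneCycleLrUpdaterHook`` that uses the matching ``pct_start``):
+        >>> hook = OneCycleMomentumUpdaterHook(base_momentum=0.85,
+        ...                                    max_momentum=0.95,
+        ...                                    pct_start=0.3)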
+    """
+
+    def __init__(self,
+                 base_momentum=0.85,
+                 max_momentum=0.95,
+                 pct_start=0.3,
+                 anneal_strategy='cos',
+                 three_phase=False,
+                 **kwargs):
+        # validate by_epoch, currently only support by_epoch=False
+        if 'by_epoch' not in kwargs:
+            kwargs['by_epoch'] = False
+        else:
+            assert not kwargs['by_epoch'], \
+                'currently only support "by_epoch" = False'
+        if not isinstance(base_momentum, (float, list, dict)):
+            raise ValueError('base_momentum must be a float, a list '
+                             'or a dict.')
+        self._base_momentum = base_momentum
+        if not isinstance(max_momentum, (float, list, dict)):
+            raise ValueError('max_momentum must be a float, a list '
+                             'or a dict.')
+        self._max_momentum = max_momentum
+        # validate pct_start
+        if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
+            raise ValueError('Expected float between 0 and 1 for pct_start, '
+                             f'but got {pct_start}')
+        self.pct_start = pct_start
+        # validate anneal_strategy
+        if anneal_strategy not in ['cos', 'linear']:
+            raise ValueError('anneal_strategy must be one of "cos" or '
+                             f'"linear", instead got {anneal_strategy}')
+        elif anneal_strategy == 'cos':
+            self.anneal_func = annealing_cos
+        elif anneal_strategy == 'linear':
+            self.anneal_func = annealing_linear
+        self.three_phase = three_phase
+        self.momentum_phases = []  # init momentum_phases
+        super(OneCycleMomentumUpdaterHook, self).__init__(**kwargs)
+
+    def before_run(self, runner):
+        if isinstance(runner.optimizer, dict):
+            for k, optim in runner.optimizer.items():
+                if ('momentum' not in optim.defaults
+                        and 'betas' not in optim.defaults):
+                    raise ValueError('optimizer must support momentum '
+                                     'or betas')
+                self.use_beta1 = 'betas' in optim.defaults
+                _base_momentum = format_param(k, optim, self._base_momentum)
+                _max_momentum = format_param(k, optim, self._max_momentum)
+                for group, b_momentum, m_momentum in zip(
+                        optim.param_groups, _base_momentum, _max_momentum):
+                    if self.use_beta1:
+                        _, beta2 = group['betas']
+                        group['betas'] = (m_momentum, beta2)
+                    else:
+                        group['momentum'] = m_momentum
+                    group['base_momentum'] = b_momentum
+                    group['max_momentum'] = m_momentum
+        else:
+            optim = runner.optimizer
+            if ('momentum' not in optim.defaults
+                    and 'betas' not in optim.defaults):
+                raise ValueError('optimizer must support momentum '
+                                 'or betas')
+            self.use_beta1 = 'betas' in optim.defaults
+            k = type(optim).__name__
+            _base_momentum = format_param(k, optim, self._base_momentum)
+            _max_momentum = format_param(k, optim, self._max_momentum)
+            for group, b_momentum, m_momentum in zip(optim.param_groups,
+                                                     _base_momentum,
+                                                     _max_momentum):
+                if self.use_beta1:
+                    _, beta2 = group['betas']
+                    group['betas'] = (m_momentum, beta2)
+                else:
+                    group['momentum'] = m_momentum
+                group['base_momentum'] = b_momentum
+                group['max_momentum'] = m_momentum
+
+        if self.three_phase:
+            self.momentum_phases.append({
+                'end_iter':
+                float(self.pct_start * runner.max_iters) - 1,
+                'start_momentum':
+                'max_momentum',
+                'end_momentum':
+                'base_momentum'
+            })
+            self.momentum_phases.append({
+                'end_iter':
+                float(2 * self.pct_start * runner.max_iters) - 2,
+                'start_momentum':
+                'base_momentum',
+                'end_momentum':
+                'max_momentum'
+            })
+            self.momentum_phases.append({
+                'end_iter': runner.max_iters - 1,
+                'start_momentum': 'max_momentum',
+                'end_momentum': 'max_momentum'
+            })
+        else:
+            self.momentum_phases.append({
+                'end_iter':
+                float(self.pct_start * runner.max_iters) - 1,
+                'start_momentum':
+                'max_momentum',
+                'end_momentum':
+                'base_momentum'
+            })
+            self.momentum_phases.append({
+                'end_iter': runner.max_iters - 1,
+                'start_momentum': 'base_momentum',
+                'end_momentum': 'max_momentum'
+            })
+
+    def _set_momentum(self, runner, momentum_groups):
+        if isinstance(runner.optimizer, dict):
+            for k, optim in runner.optimizer.items():
+                for param_group, mom in zip(optim.param_groups,
+                                            momentum_groups[k]):
+                    if 'momentum' in param_group.keys():
+                        param_group['momentum'] = mom
+                    elif 'betas' in param_group.keys():
+                        param_group['betas'] = (mom, param_group['betas'][1])
+        else:
+            for param_group, mom in zip(runner.optimizer.param_groups,
+                                        momentum_groups):
+                if 'momentum' in param_group.keys():
+                    param_group['momentum'] = mom
+                elif 'betas' in param_group.keys():
+                    param_group['betas'] = (mom, param_group['betas'][1])
+
+    def get_momentum(self, runner, param_group):
+        curr_iter = runner.iter
+        start_iter = 0
+        for i, phase in enumerate(self.momentum_phases):
+            end_iter = phase['end_iter']
+            if curr_iter <= end_iter or i == len(self.momentum_phases) - 1:
+                pct = (curr_iter - start_iter) / (end_iter - start_iter)
+                momentum = self.anneal_func(
+                    param_group[phase['start_momentum']],
+                    param_group[phase['end_momentum']], pct)
+                break
+            start_iter = end_iter
+        return momentum
+
+    def get_regular_momentum(self, runner):
+        if isinstance(runner.optimizer, dict):
+            momentum_groups = {}
+            for k, optim in runner.optimizer.items():
+                _momentum_group = [
+                    self.get_momentum(runner, param_group)
+                    for param_group in optim.param_groups
+                ]
+                momentum_groups.update({k: _momentum_group})
+            return momentum_groups
+        else:
+            momentum_groups = []
+            for param_group in runner.optimizer.param_groups:
+                momentum_groups.append(self.get_momentum(runner, param_group))
+            return momentum_groups
diff --git a/annotator/uniformer/mmcv/runner/hooks/optimizer.py b/annotator/uniformer/mmcv/runner/hooks/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ef3e9ff8f9c6926e32bdf027612267b64ed80df
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/optimizer.py
@@ -0,0 +1,508 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from collections import defaultdict
+from itertools import chain
+
+from torch.nn.utils import clip_grad
+
+from annotator.uniformer.mmcv.utils import TORCH_VERSION, _BatchNorm, digit_version
+from ..dist_utils import allreduce_grads
+from ..fp16_utils import LossScaler, wrap_fp16_model
+from .hook import HOOKS, Hook
+
+try:
+    # If PyTorch version >= 1.6.0, torch.cuda.amp.GradScaler would be imported
+    # and used; otherwise, auto fp16 will adopt mmcv's implementation.
+    from torch.cuda.amp import GradScaler
+except ImportError:
+    pass
+
+
+@HOOKS.register_module()
+class OptimizerHook(Hook):
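+    """Default optimizer hook: zero the gradients, back-propagate the loss,
+    optionally clip gradients and step the optimizer after every training
+    iteration.
+
+    Example (illustrative; the ``grad_clip`` dict is forwarded to
+    ``torch.nn.utils.clip_grad_norm_``):
+        >>> optimizer_hook = OptimizerHook(
+        ...     grad_clip=dict(max_norm=35, norm_type=2))
+    """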
+
+    def __init__(self, grad_clip=None):
+        self.grad_clip = grad_clip
+
+    def clip_grads(self, params):
+        params = list(
+            filter(lambda p: p.requires_grad and p.grad is not None, params))
+        if len(params) > 0:
+            return clip_grad.clip_grad_norm_(params, **self.grad_clip)
+
+    def after_train_iter(self, runner):
+        runner.optimizer.zero_grad()
+        runner.outputs['loss'].backward()
+        if self.grad_clip is not None:
+            grad_norm = self.clip_grads(runner.model.parameters())
+            if grad_norm is not None:
+                # Add grad norm to the logger
+                runner.log_buffer.update({'grad_norm': float(grad_norm)},
+                                         runner.outputs['num_samples'])
+        runner.optimizer.step()
+
+
+@HOOKS.register_module()
+class GradientCumulativeOptimizerHook(OptimizerHook):
+    """Optimizer Hook implements multi-iters gradient cumulating.
+
+    Args:
+        cumulative_iters (int, optional): Num of gradient cumulative iters.
+            The optimizer will step every `cumulative_iters` iters.
+            Defaults to 1.
+
+    Examples:
+        >>> # Use cumulative_iters to simulate a large batch size
+        >>> # It is helpful when the hardware cannot handle a large batch size.
+        >>> loader = DataLoader(data, batch_size=64)
+        >>> optim_hook = GradientCumulativeOptimizerHook(cumulative_iters=4)
+        >>> # almost equals to
+        >>> loader = DataLoader(data, batch_size=256)
+        >>> optim_hook = OptimizerHook()
+    """
+
+    def __init__(self, cumulative_iters=1, **kwargs):
+        super(GradientCumulativeOptimizerHook, self).__init__(**kwargs)
+
+        assert isinstance(cumulative_iters, int) and cumulative_iters > 0, \
+            f'cumulative_iters only accepts positive int, but got ' \
+            f'{type(cumulative_iters)} instead.'
+
+        self.cumulative_iters = cumulative_iters
+        self.divisible_iters = 0
+        self.remainder_iters = 0
+        self.initialized = False
+
+    def has_batch_norm(self, module):
+        if isinstance(module, _BatchNorm):
+            return True
+        for m in module.children():
+            if self.has_batch_norm(m):
+                return True
+        return False
+
+    def _init(self, runner):
+        if runner.iter % self.cumulative_iters != 0:
+            runner.logger.warning(
+                'Resume iter number is not divisible by cumulative_iters in '
+                'GradientCumulativeOptimizerHook, which means the gradient of '
+                'some iters is lost and the result may be influenced slightly.'
+            )
+
+        if self.has_batch_norm(runner.model) and self.cumulative_iters > 1:
+            runner.logger.warning(
+                'GradientCumulativeOptimizerHook may slightly decrease '
+                'performance if the model has BatchNorm layers.')
+
+        residual_iters = runner.max_iters - runner.iter
+
+        self.divisible_iters = (
+            residual_iters // self.cumulative_iters * self.cumulative_iters)
+        self.remainder_iters = residual_iters - self.divisible_iters
+
+        self.initialized = True
+
+    def after_train_iter(self, runner):
+        if not self.initialized:
+            self._init(runner)
+
+        if runner.iter < self.divisible_iters:
+            loss_factor = self.cumulative_iters
+        else:
+            loss_factor = self.remainder_iters
+        loss = runner.outputs['loss']
+        loss = loss / loss_factor
+        loss.backward()
+
+        if (self.every_n_iters(runner, self.cumulative_iters)
+                or self.is_last_iter(runner)):
+
+            if self.grad_clip is not None:
+                grad_norm = self.clip_grads(runner.model.parameters())
+                if grad_norm is not None:
+                    # Add grad norm to the logger
+                    runner.log_buffer.update({'grad_norm': float(grad_norm)},
+                                             runner.outputs['num_samples'])
+            runner.optimizer.step()
+            runner.optimizer.zero_grad()
+
+
+if (TORCH_VERSION != 'parrots'
+        and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
+
+    @HOOKS.register_module()
+    class Fp16OptimizerHook(OptimizerHook):
+        """FP16 optimizer hook (using PyTorch's implementation).
+
+        If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend,
+        to take care of the optimization procedure.
+
+        Args:
+            loss_scale (float | str | dict): Scale factor configuration.
+                If loss_scale is a float, static loss scaling will be used with
+                the specified scale. If loss_scale is a string, it must be
+                'dynamic', then dynamic loss scaling will be used.
+                It can also be a dict containing arguments of GradScaler.
+                Defaults to 512. For Pytorch >= 1.6, mmcv uses official
+                implementation of GradScaler. If you use a dict version of
+                loss_scale to create GradScaler, please refer to:
+                https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler
+                for the parameters.
+
+        Examples:
+            >>> loss_scale = dict(
+            ...     init_scale=65536.0,
+            ...     growth_factor=2.0,
+            ...     backoff_factor=0.5,
+            ...     growth_interval=2000
+            ... )
+            >>> optimizer_hook = Fp16OptimizerHook(loss_scale=loss_scale)
+        """
+
+        def __init__(self,
+                     grad_clip=None,
+                     coalesce=True,
+                     bucket_size_mb=-1,
+                     loss_scale=512.,
+                     distributed=True):
+            self.grad_clip = grad_clip
+            self.coalesce = coalesce
+            self.bucket_size_mb = bucket_size_mb
+            self.distributed = distributed
+            self._scale_update_param = None
+            if loss_scale == 'dynamic':
+                self.loss_scaler = GradScaler()
+            elif isinstance(loss_scale, float):
+                self._scale_update_param = loss_scale
+                self.loss_scaler = GradScaler(init_scale=loss_scale)
+            elif isinstance(loss_scale, dict):
+                self.loss_scaler = GradScaler(**loss_scale)
+            else:
+                raise ValueError('loss_scale must be of type float, dict, or '
+                                 f'"dynamic", got {loss_scale}')
+
+        def before_run(self, runner):
+            """Preparing steps before Mixed Precision Training."""
+            # wrap model mode to fp16
+            wrap_fp16_model(runner.model)
+            # resume from state dict
+            if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']:
+                scaler_state_dict = runner.meta['fp16']['loss_scaler']
+                self.loss_scaler.load_state_dict(scaler_state_dict)
+
+        def copy_grads_to_fp32(self, fp16_net, fp32_weights):
+            """Copy gradients from fp16 model to fp32 weight copy."""
+            for fp32_param, fp16_param in zip(fp32_weights,
+                                              fp16_net.parameters()):
+                if fp16_param.grad is not None:
+                    if fp32_param.grad is None:
+                        fp32_param.grad = fp32_param.data.new(
+                            fp32_param.size())
+                    fp32_param.grad.copy_(fp16_param.grad)
+
+        def copy_params_to_fp16(self, fp16_net, fp32_weights):
+            """Copy updated params from fp32 weight copy to fp16 model."""
+            for fp16_param, fp32_param in zip(fp16_net.parameters(),
+                                              fp32_weights):
+                fp16_param.data.copy_(fp32_param.data)
+
+        def after_train_iter(self, runner):
+            """Backward optimization steps for Mixed Precision Training. For
+            dynamic loss scaling, please refer to
+            https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.
+
+            1. Scale the loss by a scale factor.
+            2. Backward the loss to obtain the gradients.
+            3. Unscale the optimizer’s gradient tensors.
+            4. Call optimizer.step() and update scale factor.
+            5. Save loss_scaler state_dict for resume purpose.
+            """
+            # clear grads of last iteration
+            runner.model.zero_grad()
+            runner.optimizer.zero_grad()
+
+            self.loss_scaler.scale(runner.outputs['loss']).backward()
+            self.loss_scaler.unscale_(runner.optimizer)
+            # grad clip
+            if self.grad_clip is not None:
+                grad_norm = self.clip_grads(runner.model.parameters())
+                if grad_norm is not None:
+                    # Add grad norm to the logger
+                    runner.log_buffer.update({'grad_norm': float(grad_norm)},
+                                             runner.outputs['num_samples'])
+            # step the optimizer and update the scale factor
+            self.loss_scaler.step(runner.optimizer)
+            self.loss_scaler.update(self._scale_update_param)
+
+            # save state_dict of loss_scaler
+            runner.meta.setdefault(
+                'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
+
+    @HOOKS.register_module()
+    class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook,
+                                              Fp16OptimizerHook):
+        """Fp16 optimizer Hook (using PyTorch's implementation) implements
+        multi-iters gradient cumulating.
+
+        If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend,
+        to take care of the optimization procedure.
+        """
+
+        def __init__(self, *args, **kwargs):
+            super(GradientCumulativeFp16OptimizerHook,
+                  self).__init__(*args, **kwargs)
+
+        def after_train_iter(self, runner):
+            if not self.initialized:
+                self._init(runner)
+
+            if runner.iter < self.divisible_iters:
+                loss_factor = self.cumulative_iters
+            else:
+                loss_factor = self.remainder_iters
+            loss = runner.outputs['loss']
+            loss = loss / loss_factor
+
+            self.loss_scaler.scale(loss).backward()
+
+            if (self.every_n_iters(runner, self.cumulative_iters)
+                    or self.is_last_iter(runner)):
+
+                # unscale the gradients of the optimizer before optional clipping
+                self.loss_scaler.unscale_(runner.optimizer)
+
+                if self.grad_clip is not None:
+                    grad_norm = self.clip_grads(runner.model.parameters())
+                    if grad_norm is not None:
+                        # Add grad norm to the logger
+                        runner.log_buffer.update(
+                            {'grad_norm': float(grad_norm)},
+                            runner.outputs['num_samples'])
+
+                # step the optimizer and update the scale factor
+                self.loss_scaler.step(runner.optimizer)
+                self.loss_scaler.update(self._scale_update_param)
+
+                # save state_dict of loss_scaler
+                runner.meta.setdefault(
+                    'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
+
+                # clear grads
+                runner.model.zero_grad()
+                runner.optimizer.zero_grad()
+
+else:
+
+    @HOOKS.register_module()
+    class Fp16OptimizerHook(OptimizerHook):
+        """FP16 optimizer hook (mmcv's implementation).
+
+        The steps of the fp16 optimizer are as follows.
+        1. Scale the loss value.
+        2. Back-propagate through the fp16 model.
+        3. Copy gradients from the fp16 model to the fp32 weights.
+        4. Update the fp32 weights.
+        5. Copy the updated parameters from the fp32 weights to the fp16 model.
+
+        Refer to https://arxiv.org/abs/1710.03740 for more details.
+
+        Args:
+            loss_scale (float | str | dict): Scale factor configuration.
+                If loss_scale is a float, static loss scaling will be used with
+                the specified scale. If loss_scale is a string, it must be
+                'dynamic', then dynamic loss scaling will be used.
+                It can also be a dict containing arguments of LossScaler.
+                Defaults to 512.
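+
+        Examples:
+            >>> # Illustrative: static loss scaling with a fixed factor.
+            >>> optimizer_hook = Fp16OptimizerHook(loss_scale=512.)
+            >>> # Illustrative: dynamic loss scaling.
+            >>> optimizer_hook = Fp16OptimizerHook(loss_scale='dynamic')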
+        """
+
+        def __init__(self,
+                     grad_clip=None,
+                     coalesce=True,
+                     bucket_size_mb=-1,
+                     loss_scale=512.,
+                     distributed=True):
+            self.grad_clip = grad_clip
+            self.coalesce = coalesce
+            self.bucket_size_mb = bucket_size_mb
+            self.distributed = distributed
+            if loss_scale == 'dynamic':
+                self.loss_scaler = LossScaler(mode='dynamic')
+            elif isinstance(loss_scale, float):
+                self.loss_scaler = LossScaler(
+                    init_scale=loss_scale, mode='static')
+            elif isinstance(loss_scale, dict):
+                self.loss_scaler = LossScaler(**loss_scale)
+            else:
+                raise ValueError('loss_scale must be of type float, dict, or '
+                                 f'"dynamic", got {loss_scale}')
+
+        def before_run(self, runner):
+            """Preparing steps before Mixed Precision Training.
+
+            1. Make a master copy of fp32 weights for optimization.
+            2. Convert the main model from fp32 to fp16.
+            """
+            # keep a copy of fp32 weights
+            old_groups = runner.optimizer.param_groups
+            runner.optimizer.param_groups = copy.deepcopy(
+                runner.optimizer.param_groups)
+            state = defaultdict(dict)
+            p_map = {
+                old_p: p
+                for old_p, p in zip(
+                    chain(*(g['params'] for g in old_groups)),
+                    chain(*(g['params']
+                            for g in runner.optimizer.param_groups)))
+            }
+            for k, v in runner.optimizer.state.items():
+                state[p_map[k]] = v
+            runner.optimizer.state = state
+            # convert model to fp16
+            wrap_fp16_model(runner.model)
+            # resume from state dict
+            if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']:
+                scaler_state_dict = runner.meta['fp16']['loss_scaler']
+                self.loss_scaler.load_state_dict(scaler_state_dict)
+
+        def copy_grads_to_fp32(self, fp16_net, fp32_weights):
+            """Copy gradients from fp16 model to fp32 weight copy."""
+            for fp32_param, fp16_param in zip(fp32_weights,
+                                              fp16_net.parameters()):
+                if fp16_param.grad is not None:
+                    if fp32_param.grad is None:
+                        fp32_param.grad = fp32_param.data.new(
+                            fp32_param.size())
+                    fp32_param.grad.copy_(fp16_param.grad)
+
+        def copy_params_to_fp16(self, fp16_net, fp32_weights):
+            """Copy updated params from fp32 weight copy to fp16 model."""
+            for fp16_param, fp32_param in zip(fp16_net.parameters(),
+                                              fp32_weights):
+                fp16_param.data.copy_(fp32_param.data)
+
+        def after_train_iter(self, runner):
+            """Backward optimization steps for Mixed Precision Training. For
+            dynamic loss scaling, please refer to ``LossScaler`` in
+            ``fp16_utils.py``.
+
+            1. Scale the loss by a scale factor.
+            2. Backward the loss to obtain the gradients (fp16).
+            3. Copy gradients from the model to the fp32 weight copy.
+            4. Scale the gradients back and update the fp32 weight copy.
+            5. Copy back the params from fp32 weight copy to the fp16 model.
+            6. Save loss_scaler state_dict for resume purpose.
+            """
+            # clear grads of last iteration
+            runner.model.zero_grad()
+            runner.optimizer.zero_grad()
+            # scale the loss value
+            scaled_loss = runner.outputs['loss'] * self.loss_scaler.loss_scale
+            scaled_loss.backward()
+            # copy fp16 grads in the model to fp32 params in the optimizer
+
+            fp32_weights = []
+            for param_group in runner.optimizer.param_groups:
+                fp32_weights += param_group['params']
+            self.copy_grads_to_fp32(runner.model, fp32_weights)
+            # allreduce grads
+            if self.distributed:
+                allreduce_grads(fp32_weights, self.coalesce,
+                                self.bucket_size_mb)
+
+            has_overflow = self.loss_scaler.has_overflow(fp32_weights)
+            # if has overflow, skip this iteration
+            if not has_overflow:
+                # scale the gradients back
+                for param in fp32_weights:
+                    if param.grad is not None:
+                        param.grad.div_(self.loss_scaler.loss_scale)
+                if self.grad_clip is not None:
+                    grad_norm = self.clip_grads(fp32_weights)
+                    if grad_norm is not None:
+                        # Add grad norm to the logger
+                        runner.log_buffer.update(
+                            {'grad_norm': float(grad_norm)},
+                            runner.outputs['num_samples'])
+                # update fp32 params
+                runner.optimizer.step()
+                # copy fp32 params to the fp16 model
+                self.copy_params_to_fp16(runner.model, fp32_weights)
+            self.loss_scaler.update_scale(has_overflow)
+            if has_overflow:
+                runner.logger.warning('Check overflow, downscale loss scale '
+                                      f'to {self.loss_scaler.cur_scale}')
+
+            # save state_dict of loss_scaler
+            runner.meta.setdefault(
+                'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
+
+    @HOOKS.register_module()
+    class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook,
+                                              Fp16OptimizerHook):
+        """Fp16 optimizer Hook (using mmcv implementation) implements multi-
+        iters gradient cumulating."""
+
+        def __init__(self, *args, **kwargs):
+            super(GradientCumulativeFp16OptimizerHook,
+                  self).__init__(*args, **kwargs)
+
+        def after_train_iter(self, runner):
+            if not self.initialized:
+                self._init(runner)
+
+            if runner.iter < self.divisible_iters:
+                loss_factor = self.cumulative_iters
+            else:
+                loss_factor = self.remainder_iters
+
+            loss = runner.outputs['loss']
+            loss = loss / loss_factor
+
+            # scale the loss value
+            scaled_loss = loss * self.loss_scaler.loss_scale
+            scaled_loss.backward()
+
+            if (self.every_n_iters(runner, self.cumulative_iters)
+                    or self.is_last_iter(runner)):
+
+                # copy fp16 grads in the model to fp32 params in the optimizer
+                fp32_weights = []
+                for param_group in runner.optimizer.param_groups:
+                    fp32_weights += param_group['params']
+                self.copy_grads_to_fp32(runner.model, fp32_weights)
+                # allreduce grads
+                if self.distributed:
+                    allreduce_grads(fp32_weights, self.coalesce,
+                                    self.bucket_size_mb)
+
+                has_overflow = self.loss_scaler.has_overflow(fp32_weights)
+                # if has overflow, skip this iteration
+                if not has_overflow:
+                    # scale the gradients back
+                    for param in fp32_weights:
+                        if param.grad is not None:
+                            param.grad.div_(self.loss_scaler.loss_scale)
+                    if self.grad_clip is not None:
+                        grad_norm = self.clip_grads(fp32_weights)
+                        if grad_norm is not None:
+                            # Add grad norm to the logger
+                            runner.log_buffer.update(
+                                {'grad_norm': float(grad_norm)},
+                                runner.outputs['num_samples'])
+                    # update fp32 params
+                    runner.optimizer.step()
+                    # copy fp32 params to the fp16 model
+                    self.copy_params_to_fp16(runner.model, fp32_weights)
+                else:
+                    runner.logger.warning(
+                        'Check overflow, downscale loss scale '
+                        f'to {self.loss_scaler.cur_scale}')
+
+                self.loss_scaler.update_scale(has_overflow)
+
+                # save state_dict of loss_scaler
+                runner.meta.setdefault(
+                    'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict()
+
+                # clear grads
+                runner.model.zero_grad()
+                runner.optimizer.zero_grad()
diff --git a/annotator/uniformer/mmcv/runner/hooks/profiler.py b/annotator/uniformer/mmcv/runner/hooks/profiler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b70236997eec59c2209ef351ae38863b4112d0ec
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/profiler.py
@@ -0,0 +1,180 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from typing import Callable, List, Optional, Union
+
+import torch
+
+from ..dist_utils import master_only
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class ProfilerHook(Hook):
+    """Profiler to analyze performance during training.
+
+    PyTorch Profiler is a tool that collects performance metrics during
+    training. More details on Profiler can be found at
+    https://pytorch.org/docs/1.8.1/profiler.html#torch.profiler.profile
+
+    Args:
+        by_epoch (bool): Profile performance by epoch or by iteration.
+            Default: True.
+        profile_iters (int): Number of iterations for profiling.
+            If ``by_epoch=True``, profiling covers the first ``profile_iters``
+            epochs of training; otherwise it covers the first
+            ``profile_iters`` iterations. Default: 1.
+        activities (list[str]): List of activity groups (CPU, CUDA) to use in
+            profiling. Default: ['cpu', 'cuda'].
+        schedule (dict, optional): Config of generating the callable schedule.
+            if schedule is None, profiler will not add step markers into the
+            trace and table view. Default: None.
+        on_trace_ready (Callable | dict, optional): Either a trace handler or
+            a dict used to build one. Default: None.
+        record_shapes (bool): Save information about operator's input shapes.
+            Default: False.
+        profile_memory (bool): Track tensor memory allocation/deallocation.
+            Default: False.
+        with_stack (bool): Record source information (file and line number)
+            for the ops. Default: False.
+        with_flops (bool): Use formula to estimate the FLOPS of specific
+            operators (matrix multiplication and 2D convolution).
+            Default: False.
+        json_trace_path (str, optional): Exports the collected trace in Chrome
+            JSON format. Default: None.
+
+    Example:
+        >>> runner = ... # instantiate a Runner
+        >>> # tensorboard trace
+        >>> trace_config = dict(type='tb_trace', dir_name='work_dir')
+        >>> profiler_config = dict(on_trace_ready=trace_config)
+        >>> runner.register_profiler_hook(profiler_config)
+        >>> runner.run(data_loaders=[trainloader], workflow=[('train', 1)])
+    """
+
+    def __init__(self,
+                 by_epoch: bool = True,
+                 profile_iters: int = 1,
+                 activities: List[str] = ['cpu', 'cuda'],
+                 schedule: Optional[dict] = None,
+                 on_trace_ready: Optional[Union[Callable, dict]] = None,
+                 record_shapes: bool = False,
+                 profile_memory: bool = False,
+                 with_stack: bool = False,
+                 with_flops: bool = False,
+                 json_trace_path: Optional[str] = None) -> None:
+        try:
+            from torch import profiler  # torch version >= 1.8.1
+        except ImportError:
+            raise ImportError('profiler is a new feature of torch>=1.8.1, '
+                              f'but your version is {torch.__version__}')
+
+        assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean.'
+        self.by_epoch = by_epoch
+
+        if profile_iters < 1:
+            raise ValueError('profile_iters should be greater than 0, but got '
+                             f'{profile_iters}')
+        self.profile_iters = profile_iters
+
+        if not isinstance(activities, list):
+            raise ValueError(
+                f'activities should be list, but got {type(activities)}')
+        self.activities = []
+        for activity in activities:
+            activity = activity.lower()
+            if activity == 'cpu':
+                self.activities.append(profiler.ProfilerActivity.CPU)
+            elif activity == 'cuda':
+                self.activities.append(profiler.ProfilerActivity.CUDA)
+            else:
+                raise ValueError(
+                    f'activity should be "cpu" or "cuda", but got {activity}')
+
+        if schedule is not None:
+            self.schedule = profiler.schedule(**schedule)
+        else:
+            self.schedule = None
+
+        self.on_trace_ready = on_trace_ready
+        self.record_shapes = record_shapes
+        self.profile_memory = profile_memory
+        self.with_stack = with_stack
+        self.with_flops = with_flops
+        self.json_trace_path = json_trace_path
+
+    @master_only
+    def before_run(self, runner):
+        if self.by_epoch and runner.max_epochs < self.profile_iters:
+            raise ValueError('self.profile_iters should not be greater than '
+                             f'{runner.max_epochs}')
+
+        if not self.by_epoch and runner.max_iters < self.profile_iters:
+            raise ValueError('self.profile_iters should not be greater than '
+                             f'{runner.max_iters}')
+
+        if callable(self.on_trace_ready):  # handler
+            _on_trace_ready = self.on_trace_ready
+        elif isinstance(self.on_trace_ready, dict):  # config of handler
+            trace_cfg = self.on_trace_ready.copy()
+            trace_type = trace_cfg.pop('type')  # log_trace handler
+            if trace_type == 'log_trace':
+
+                def _log_handler(prof):
+                    print(prof.key_averages().table(**trace_cfg))
+
+                _on_trace_ready = _log_handler
+            elif trace_type == 'tb_trace':  # tensorboard_trace handler
+                try:
+                    import torch_tb_profiler  # noqa: F401
+                except ImportError:
+                    raise ImportError('please run "pip install '
+                                      'torch-tb-profiler" to install '
+                                      'torch_tb_profiler')
+                _on_trace_ready = torch.profiler.tensorboard_trace_handler(
+                    **trace_cfg)
+            else:
+                raise ValueError('trace_type should be "log_trace" or '
+                                 f'"tb_trace", but got {trace_type}')
+        elif self.on_trace_ready is None:
+            _on_trace_ready = None  # type: ignore
+        else:
+            raise ValueError('on_trace_ready should be handler, dict or None, '
+                             f'but got {type(self.on_trace_ready)}')
+
+        if runner.max_epochs > 1:
+            warnings.warn(f'profiler will profile {runner.max_epochs} epochs '
+                          'instead of 1 epoch. Since profiler will slow down '
+                          'the training, it is recommended to train 1 epoch '
+                          'with ProfilerHook and adjust your setting according'
+                          ' to the profiler summary. During normal training '
+                          '(epoch > 1), you may disable the ProfilerHook.')
+
+        self.profiler = torch.profiler.profile(
+            activities=self.activities,
+            schedule=self.schedule,
+            on_trace_ready=_on_trace_ready,
+            record_shapes=self.record_shapes,
+            profile_memory=self.profile_memory,
+            with_stack=self.with_stack,
+            with_flops=self.with_flops)
+
+        self.profiler.__enter__()
+        runner.logger.info('profiler is profiling...')
+
+    @master_only
+    def after_train_epoch(self, runner):
+        if self.by_epoch and runner.epoch == self.profile_iters - 1:
+            runner.logger.info('profiler may take a few minutes...')
+            self.profiler.__exit__(None, None, None)
+            if self.json_trace_path is not None:
+                self.profiler.export_chrome_trace(self.json_trace_path)
+
+    @master_only
+    def after_train_iter(self, runner):
+        self.profiler.step()
+        if not self.by_epoch and runner.iter == self.profile_iters - 1:
+            runner.logger.info('profiler may take a few minutes...')
+            self.profiler.__exit__(None, None, None)
+            if self.json_trace_path is not None:
+                self.profiler.export_chrome_trace(self.json_trace_path)
diff --git a/annotator/uniformer/mmcv/runner/hooks/sampler_seed.py b/annotator/uniformer/mmcv/runner/hooks/sampler_seed.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee0dc6bdd8df5775857028aaed5444c0f59caf80
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/sampler_seed.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class DistSamplerSeedHook(Hook):
+    """Data-loading sampler for distributed training.
+
+    When distributed training, it is only useful in conjunction with
+    :obj:`EpochBasedRunner`, while :obj:`IterBasedRunner` achieves the same
+    purpose with :obj:`IterLoader`.
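+
+    Example (illustrative sketch; ``runner`` is assumed to be an existing
+    :obj:`EpochBasedRunner`):
+        >>> runner.register_hook(DistSamplerSeedHook())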
+    """
+
+    def before_epoch(self, runner):
+        if hasattr(runner.data_loader.sampler, 'set_epoch'):
+            # in case the data loader uses `SequentialSampler` in Pytorch
+            runner.data_loader.sampler.set_epoch(runner.epoch)
+        elif hasattr(runner.data_loader.batch_sampler.sampler, 'set_epoch'):
+            # batch samplers in PyTorch wrap the sampler as an attribute.
+            runner.data_loader.batch_sampler.sampler.set_epoch(runner.epoch)
diff --git a/annotator/uniformer/mmcv/runner/hooks/sync_buffer.py b/annotator/uniformer/mmcv/runner/hooks/sync_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6376b7ff894280cb2782243b25e8973650591577
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/hooks/sync_buffer.py
@@ -0,0 +1,22 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..dist_utils import allreduce_params
+from .hook import HOOKS, Hook
+
+
+@HOOKS.register_module()
+class SyncBuffersHook(Hook):
+    """Synchronize model buffers such as running_mean and running_var in BN at
+    the end of each epoch.
+
+    Args:
+        distributed (bool): Whether distributed training is used. It is
+          effective only for distributed training. Defaults to True.
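+
+    Example (illustrative sketch; ``runner`` is assumed to be an existing
+    runner built for distributed training):
+        >>> runner.register_hook(SyncBuffersHook(distributed=True))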
+    """
+
+    def __init__(self, distributed=True):
+        self.distributed = distributed
+
+    def after_epoch(self, runner):
+        """All-reduce model buffers at the end of each epoch."""
+        if self.distributed:
+            allreduce_params(runner.model.buffers())
diff --git a/annotator/uniformer/mmcv/runner/iter_based_runner.py b/annotator/uniformer/mmcv/runner/iter_based_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..1df4de8c0285669dec9b014dfd1f3dd1600f0831
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/iter_based_runner.py
@@ -0,0 +1,273 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+import platform
+import shutil
+import time
+import warnings
+
+import torch
+from torch.optim import Optimizer
+
+import annotator.uniformer.mmcv as mmcv
+from .base_runner import BaseRunner
+from .builder import RUNNERS
+from .checkpoint import save_checkpoint
+from .hooks import IterTimerHook
+from .utils import get_host_info
+
+
+class IterLoader:
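+    """Wrap a ``DataLoader`` so that it can be iterated over indefinitely.
+
+    When the underlying loader is exhausted, the epoch counter is incremented,
+    ``set_epoch`` is called on the sampler if it is available, and a fresh
+    iterator is created.
+    """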
+
+    def __init__(self, dataloader):
+        self._dataloader = dataloader
+        self.iter_loader = iter(self._dataloader)
+        self._epoch = 0
+
+    @property
+    def epoch(self):
+        return self._epoch
+
+    def __next__(self):
+        try:
+            data = next(self.iter_loader)
+        except StopIteration:
+            self._epoch += 1
+            if hasattr(self._dataloader.sampler, 'set_epoch'):
+                self._dataloader.sampler.set_epoch(self._epoch)
+            time.sleep(2)  # Prevent possible deadlock during epoch transition
+            self.iter_loader = iter(self._dataloader)
+            data = next(self.iter_loader)
+
+        return data
+
+    def __len__(self):
+        return len(self._dataloader)
+
+
+@RUNNERS.register_module()
+class IterBasedRunner(BaseRunner):
+    """Iteration-based Runner.
+
+    This runner trains models iteration by iteration.
+    """
+
+    def train(self, data_loader, **kwargs):
+        self.model.train()
+        self.mode = 'train'
+        self.data_loader = data_loader
+        self._epoch = data_loader.epoch
+        data_batch = next(data_loader)
+        self.call_hook('before_train_iter')
+        outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
+        if not isinstance(outputs, dict):
+            raise TypeError('model.train_step() must return a dict')
+        if 'log_vars' in outputs:
+            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
+        self.outputs = outputs
+        self.call_hook('after_train_iter')
+        self._inner_iter += 1
+        self._iter += 1
+
+    @torch.no_grad()
+    def val(self, data_loader, **kwargs):
+        self.model.eval()
+        self.mode = 'val'
+        self.data_loader = data_loader
+        data_batch = next(data_loader)
+        self.call_hook('before_val_iter')
+        outputs = self.model.val_step(data_batch, **kwargs)
+        if not isinstance(outputs, dict):
+            raise TypeError('model.val_step() must return a dict')
+        if 'log_vars' in outputs:
+            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
+        self.outputs = outputs
+        self.call_hook('after_val_iter')
+        self._inner_iter += 1
+
+    def run(self, data_loaders, workflow, max_iters=None, **kwargs):
+        """Start running.
+
+        Args:
+            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
+                and validation.
+            workflow (list[tuple]): A list of (phase, iters) to specify the
+                running order and iterations. E.g., [('train', 10000),
+                ('val', 1000)] means running 10000 iterations for training and
+                1000 iterations for validation, iteratively.
+        """
+        assert isinstance(data_loaders, list)
+        assert mmcv.is_list_of(workflow, tuple)
+        assert len(data_loaders) == len(workflow)
+        if max_iters is not None:
+            warnings.warn(
+                'setting max_iters in run is deprecated, '
+                'please set max_iters in runner_config', DeprecationWarning)
+            self._max_iters = max_iters
+        assert self._max_iters is not None, (
+            'max_iters must be specified during instantiation')
+
+        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
+        self.logger.info('Start running, host: %s, work_dir: %s',
+                         get_host_info(), work_dir)
+        self.logger.info('Hooks will be executed in the following order:\n%s',
+                         self.get_hook_info())
+        self.logger.info('workflow: %s, max: %d iters', workflow,
+                         self._max_iters)
+        self.call_hook('before_run')
+
+        iter_loaders = [IterLoader(x) for x in data_loaders]
+
+        self.call_hook('before_epoch')
+
+        while self.iter < self._max_iters:
+            for i, flow in enumerate(workflow):
+                self._inner_iter = 0
+                mode, iters = flow
+                if not isinstance(mode, str) or not hasattr(self, mode):
+                    raise ValueError(
+                        'runner has no method named "{}" to run a workflow'.
+                        format(mode))
+                iter_runner = getattr(self, mode)
+                for _ in range(iters):
+                    if mode == 'train' and self.iter >= self._max_iters:
+                        break
+                    iter_runner(iter_loaders[i], **kwargs)
+
+        time.sleep(1)  # wait for some hooks like loggers to finish
+        self.call_hook('after_epoch')
+        self.call_hook('after_run')
+
+    def resume(self,
+               checkpoint,
+               resume_optimizer=True,
+               map_location='default'):
+        """Resume model from checkpoint.
+
+        Args:
+            checkpoint (str): Checkpoint to resume from.
+            resume_optimizer (bool, optional): Whether to resume the optimizer
+                if the checkpoint file includes optimizer(s). Defaults to True.
+            map_location (str, optional): Same as :func:`torch.load`.
+                Defaults to 'default'.
+        """
+        if map_location == 'default':
+            device_id = torch.cuda.current_device()
+            checkpoint = self.load_checkpoint(
+                checkpoint,
+                map_location=lambda storage, loc: storage.cuda(device_id))
+        else:
+            checkpoint = self.load_checkpoint(
+                checkpoint, map_location=map_location)
+
+        self._epoch = checkpoint['meta']['epoch']
+        self._iter = checkpoint['meta']['iter']
+        self._inner_iter = checkpoint['meta']['iter']
+        if 'optimizer' in checkpoint and resume_optimizer:
+            if isinstance(self.optimizer, Optimizer):
+                self.optimizer.load_state_dict(checkpoint['optimizer'])
+            elif isinstance(self.optimizer, dict):
+                for k in self.optimizer.keys():
+                    self.optimizer[k].load_state_dict(
+                        checkpoint['optimizer'][k])
+            else:
+                raise TypeError(
+                    'Optimizer should be dict or torch.optim.Optimizer '
+                    f'but got {type(self.optimizer)}')
+
+        self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')
+
+    def save_checkpoint(self,
+                        out_dir,
+                        filename_tmpl='iter_{}.pth',
+                        meta=None,
+                        save_optimizer=True,
+                        create_symlink=True):
+        """Save checkpoint to file.
+
+        Args:
+            out_dir (str): Directory to save checkpoint files.
+            filename_tmpl (str, optional): Checkpoint file template.
+                Defaults to 'iter_{}.pth'.
+            meta (dict, optional): Metadata to be saved in checkpoint.
+                Defaults to None.
+            save_optimizer (bool, optional): Whether save optimizer.
+                Defaults to True.
+            create_symlink (bool, optional): Whether create symlink to the
+                latest checkpoint file. Defaults to True.
+        """
+        if meta is None:
+            meta = {}
+        elif not isinstance(meta, dict):
+            raise TypeError(
+                f'meta should be a dict or None, but got {type(meta)}')
+        if self.meta is not None:
+            meta.update(self.meta)
+            # Note: meta.update(self.meta) should be done before
+            # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise
+            # there will be problems with resumed checkpoints.
+            # More details in https://github.com/open-mmlab/mmcv/pull/1108
+        meta.update(epoch=self.epoch + 1, iter=self.iter)
+
+        filename = filename_tmpl.format(self.iter + 1)
+        filepath = osp.join(out_dir, filename)
+        optimizer = self.optimizer if save_optimizer else None
+        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
+        # in some environments, `os.symlink` is not supported, you may need to
+        # set `create_symlink` to False
+        if create_symlink:
+            dst_file = osp.join(out_dir, 'latest.pth')
+            if platform.system() != 'Windows':
+                mmcv.symlink(filename, dst_file)
+            else:
+                shutil.copy(filepath, dst_file)
+
+    def register_training_hooks(self,
+                                lr_config,
+                                optimizer_config=None,
+                                checkpoint_config=None,
+                                log_config=None,
+                                momentum_config=None,
+                                custom_hooks_config=None):
+        """Register default hooks for iter-based training.
+
+        Checkpoint hook, optimizer stepper hook and logger hooks will be set to
+        `by_epoch=False` by default.
+
+        Default hooks include:
+
+        +----------------------+-------------------------+
+        | Hooks                | Priority                |
+        +======================+=========================+
+        | LrUpdaterHook        | VERY_HIGH (10)          |
+        +----------------------+-------------------------+
+        | MomentumUpdaterHook  | HIGH (30)               |
+        +----------------------+-------------------------+
+        | OptimizerStepperHook | ABOVE_NORMAL (40)       |
+        +----------------------+-------------------------+
+        | CheckpointSaverHook  | NORMAL (50)             |
+        +----------------------+-------------------------+
+        | IterTimerHook        | LOW (70)                |
+        +----------------------+-------------------------+
+        | LoggerHook(s)        | VERY_LOW (90)           |
+        +----------------------+-------------------------+
+        | CustomHook(s)        | defaults to NORMAL (50) |
+        +----------------------+-------------------------+
+
+        If custom hooks have the same priority as default hooks, the custom
+        hooks will be triggered after the default hooks.
+        """
+        if checkpoint_config is not None:
+            checkpoint_config.setdefault('by_epoch', False)
+        if lr_config is not None:
+            lr_config.setdefault('by_epoch', False)
+        if log_config is not None:
+            for info in log_config['hooks']:
+                info.setdefault('by_epoch', False)
+        super(IterBasedRunner, self).register_training_hooks(
+            lr_config=lr_config,
+            momentum_config=momentum_config,
+            optimizer_config=optimizer_config,
+            checkpoint_config=checkpoint_config,
+            log_config=log_config,
+            timer_config=IterTimerHook(),
+            custom_hooks_config=custom_hooks_config)
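A stripped-down sketch of the scheduling behaviour implemented by `IterBasedRunner.run` above (plain Python counters standing in for the real `train()`/`val()` calls, so the numbers are purely illustrative): workflow tuples are cycled until `max_iters` training iterations have been consumed, and 'val' iterations do not count toward `max_iters`.

```python
max_iters = 20
workflow = [('train', 8), ('val', 2)]  # same format as run()'s docstring

cur_iter = 0
schedule = []
while cur_iter < max_iters:
    for mode, iters in workflow:
        for _ in range(iters):
            if mode == 'train' and cur_iter >= max_iters:
                break
            schedule.append(mode)
            if mode == 'train':
                cur_iter += 1

# 20 'train' entries interleaved with blocks of 'val' entries.
print(schedule.count('train'), schedule.count('val'))
```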
diff --git a/annotator/uniformer/mmcv/runner/log_buffer.py b/annotator/uniformer/mmcv/runner/log_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d949e2941c5400088c7cd8a1dc893d8b233ae785
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/log_buffer.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import OrderedDict
+
+import numpy as np
+
+
+class LogBuffer:
+
+    def __init__(self):
+        self.val_history = OrderedDict()
+        self.n_history = OrderedDict()
+        self.output = OrderedDict()
+        self.ready = False
+
+    def clear(self):
+        self.val_history.clear()
+        self.n_history.clear()
+        self.clear_output()
+
+    def clear_output(self):
+        self.output.clear()
+        self.ready = False
+
+    def update(self, vars, count=1):
+        assert isinstance(vars, dict)
+        for key, var in vars.items():
+            if key not in self.val_history:
+                self.val_history[key] = []
+                self.n_history[key] = []
+            self.val_history[key].append(var)
+            self.n_history[key].append(count)
+
+    def average(self, n=0):
+        """Average latest n values or all values."""
+        assert n >= 0
+        for key in self.val_history:
+            values = np.array(self.val_history[key][-n:])
+            nums = np.array(self.n_history[key][-n:])
+            avg = np.sum(values * nums) / np.sum(nums)
+            self.output[key] = avg
+        self.ready = True
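A quick usage sketch of `LogBuffer` (assuming the vendored `annotator.uniformer.mmcv` package and its dependencies import cleanly): values are stored together with their sample counts, and `average(n)` is a count-weighted mean over the latest `n` updates, which is what the logger hooks report.

```python
from annotator.uniformer.mmcv.runner.log_buffer import LogBuffer

buf = LogBuffer()
buf.update({'loss': 1.0}, count=2)
buf.update({'loss': 0.5}, count=4)
buf.update({'loss': 0.2}, count=2)

buf.average(2)              # weighted mean over the latest 2 updates per key
print(buf.output['loss'])   # (0.5 * 4 + 0.2 * 2) / 6 = 0.4
print(buf.ready)            # True until clear_output() is called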
diff --git a/annotator/uniformer/mmcv/runner/optimizer/__init__.py b/annotator/uniformer/mmcv/runner/optimizer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..53c34d0470992cbc374f29681fdd00dc0e57968d
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/optimizer/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .builder import (OPTIMIZER_BUILDERS, OPTIMIZERS, build_optimizer,
+                      build_optimizer_constructor)
+from .default_constructor import DefaultOptimizerConstructor
+
+__all__ = [
+    'OPTIMIZER_BUILDERS', 'OPTIMIZERS', 'DefaultOptimizerConstructor',
+    'build_optimizer', 'build_optimizer_constructor'
+]
diff --git a/annotator/uniformer/mmcv/runner/optimizer/builder.py b/annotator/uniformer/mmcv/runner/optimizer/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9234eed8f1f186d9d8dfda34562157ee39bdb3a
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/optimizer/builder.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import inspect
+
+import torch
+
+from ...utils import Registry, build_from_cfg
+
+OPTIMIZERS = Registry('optimizer')
+OPTIMIZER_BUILDERS = Registry('optimizer builder')
+
+
+def register_torch_optimizers():
+    torch_optimizers = []
+    for module_name in dir(torch.optim):
+        if module_name.startswith('__'):
+            continue
+        _optim = getattr(torch.optim, module_name)
+        if inspect.isclass(_optim) and issubclass(_optim,
+                                                  torch.optim.Optimizer):
+            OPTIMIZERS.register_module()(_optim)
+            torch_optimizers.append(module_name)
+    return torch_optimizers
+
+
+TORCH_OPTIMIZERS = register_torch_optimizers()
+
+
+def build_optimizer_constructor(cfg):
+    return build_from_cfg(cfg, OPTIMIZER_BUILDERS)
+
+
+def build_optimizer(model, cfg):
+    optimizer_cfg = copy.deepcopy(cfg)
+    constructor_type = optimizer_cfg.pop('constructor',
+                                         'DefaultOptimizerConstructor')
+    paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
+    optim_constructor = build_optimizer_constructor(
+        dict(
+            type=constructor_type,
+            optimizer_cfg=optimizer_cfg,
+            paramwise_cfg=paramwise_cfg))
+    optimizer = optim_constructor(model)
+    return optimizer
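A self-contained sketch of the idea behind `register_torch_optimizers` and `build_optimizer` (plain `torch` only, bypassing the registry, so treat it as an approximation rather than the real builder): every `Optimizer` subclass in `torch.optim` becomes addressable by its class name, and the remaining config keys become constructor kwargs.

```python
import inspect

import torch

# Collect optimizer classes by name, as register_torch_optimizers does.
torch_optimizers = {
    name: obj
    for name, obj in vars(torch.optim).items()
    if not name.startswith('__')
    and inspect.isclass(obj)
    and issubclass(obj, torch.optim.Optimizer)
}
print(sorted(torch_optimizers))  # e.g. [..., 'Adam', 'AdamW', ..., 'SGD', ...]

# Building by name, roughly what build_optimizer does via the registry.
model = torch.nn.Linear(4, 2)
cfg = dict(type='SGD', lr=0.01, momentum=0.9)
optimizer = torch_optimizers[cfg.pop('type')](model.parameters(), **cfg)
print(type(optimizer).__name__)  # SGD
```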
diff --git a/annotator/uniformer/mmcv/runner/optimizer/default_constructor.py b/annotator/uniformer/mmcv/runner/optimizer/default_constructor.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c0da3503b75441738efe38d70352b55a210a34a
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/optimizer/default_constructor.py
@@ -0,0 +1,249 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
+from torch.nn import GroupNorm, LayerNorm
+
+from annotator.uniformer.mmcv.utils import _BatchNorm, _InstanceNorm, build_from_cfg, is_list_of
+from annotator.uniformer.mmcv.utils.ext_loader import check_ops_exist
+from .builder import OPTIMIZER_BUILDERS, OPTIMIZERS
+
+
+@OPTIMIZER_BUILDERS.register_module()
+class DefaultOptimizerConstructor:
+    """Default constructor for optimizers.
+
+    By default each parameter shares the same optimizer settings, and we
+    provide an argument ``paramwise_cfg`` to specify parameter-wise settings.
+    It is a dict and may contain the following fields:
+
+    - ``custom_keys`` (dict): Specifies parameter-wise settings by keys. If
+      one of the keys in ``custom_keys`` is a substring of the name of a
+      parameter, then the settings of that parameter will be specified by
+      ``custom_keys[key]``, and other settings such as ``bias_lr_mult`` will
+      be ignored. Note that the aforementioned ``key`` is the longest key
+      that is a substring of the parameter's name. If there are multiple
+      matched keys with the same length, the key that comes first in
+      alphabetical order is chosen.
+      ``custom_keys[key]`` should be a dict and may contain the fields
+      ``lr_mult`` and ``decay_mult``. See Example 2 below.
+    - ``bias_lr_mult`` (float): The learning rate for all bias parameters
+      (except those in normalization layers and offset layers of DCN) will
+      be multiplied by this value.
+    - ``bias_decay_mult`` (float): The weight decay for all bias parameters
+      (except those in normalization layers, depthwise conv layers and
+      offset layers of DCN) will be multiplied by this value.
+    - ``norm_decay_mult`` (float): The weight decay for all weight and bias
+      parameters of normalization layers will be multiplied by this value.
+    - ``dwconv_decay_mult`` (float): The weight decay for all weight and bias
+      parameters of depthwise conv layers will be multiplied by this value.
+    - ``dcn_offset_lr_mult`` (float): The learning rate for the parameters of
+      the offset layers in the deformable convs of a model will be multiplied
+      by this value.
+    - ``bypass_duplicate`` (bool): If True, duplicate parameters will not be
+      added to the optimizer. Default: False.
+
+    Note:
+        1. If the option ``dcn_offset_lr_mult`` is used, the constructor will
+            override the effect of ``bias_lr_mult`` on the bias of the offset
+            layer, so be careful when using both ``bias_lr_mult`` and
+            ``dcn_offset_lr_mult``. If you wish to apply both of them to the
+            offset layer in deformable convs, set ``dcn_offset_lr_mult``
+            to the original ``dcn_offset_lr_mult`` * ``bias_lr_mult``.
+        2. If the option ``dcn_offset_lr_mult`` is used, the constructor will
+            apply it to all the DCN layers in the model, so be careful when
+            the model contains multiple DCN layers in places other than the
+            backbone.
+
+    Args:
+        model (:obj:`nn.Module`): The model with parameters to be optimized.
+        optimizer_cfg (dict): The config dict of the optimizer.
+            Positional fields are
+
+                - `type`: class name of the optimizer.
+
+            Optional fields are
+
+                - any arguments of the corresponding optimizer type, e.g.,
+                  lr, weight_decay, momentum, etc.
+        paramwise_cfg (dict, optional): Parameter-wise options.
+
+    Example 1:
+        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
+        >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
+        >>>                      weight_decay=0.0001)
+        >>> paramwise_cfg = dict(norm_decay_mult=0.)
+        >>> optim_builder = DefaultOptimizerConstructor(
+        >>>     optimizer_cfg, paramwise_cfg)
+        >>> optimizer = optim_builder(model)
+
+    Example 2:
+        >>> # assume the model has attributes model.backbone and model.cls_head
+        >>> optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.95)
+        >>> paramwise_cfg = dict(custom_keys={
+                '.backbone': dict(lr_mult=0.1, decay_mult=0.9)})
+        >>> optim_builder = DefaultOptimizerConstructor(
+        >>>     optimizer_cfg, paramwise_cfg)
+        >>> optimizer = optim_builder(model)
+        >>> # Then the `lr` and `weight_decay` for model.backbone are
+        >>> # (0.01 * 0.1, 0.95 * 0.9), while `lr` and `weight_decay` for
+        >>> # model.cls_head are (0.01, 0.95).
+    """
+
+    def __init__(self, optimizer_cfg, paramwise_cfg=None):
+        if not isinstance(optimizer_cfg, dict):
+            raise TypeError('optimizer_cfg should be a dict',
+                            f'but got {type(optimizer_cfg)}')
+        self.optimizer_cfg = optimizer_cfg
+        self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg
+        self.base_lr = optimizer_cfg.get('lr', None)
+        self.base_wd = optimizer_cfg.get('weight_decay', None)
+        self._validate_cfg()
+
+    def _validate_cfg(self):
+        if not isinstance(self.paramwise_cfg, dict):
+            raise TypeError('paramwise_cfg should be None or a dict, '
+                            f'but got {type(self.paramwise_cfg)}')
+
+        if 'custom_keys' in self.paramwise_cfg:
+            if not isinstance(self.paramwise_cfg['custom_keys'], dict):
+                raise TypeError(
+                    'If specified, custom_keys must be a dict, '
+                    f'but got {type(self.paramwise_cfg["custom_keys"])}')
+            if self.base_wd is None:
+                for key in self.paramwise_cfg['custom_keys']:
+                    if 'decay_mult' in self.paramwise_cfg['custom_keys'][key]:
+                        raise ValueError('base_wd should not be None')
+
+        # get base lr and weight decay
+        # weight_decay must be explicitly specified if mult is specified
+        if ('bias_decay_mult' in self.paramwise_cfg
+                or 'norm_decay_mult' in self.paramwise_cfg
+                or 'dwconv_decay_mult' in self.paramwise_cfg):
+            if self.base_wd is None:
+                raise ValueError('base_wd should not be None')
+
+    def _is_in(self, param_group, param_group_list):
+        assert is_list_of(param_group_list, dict)
+        param = set(param_group['params'])
+        param_set = set()
+        for group in param_group_list:
+            param_set.update(set(group['params']))
+
+        return not param.isdisjoint(param_set)
+
+    def add_params(self, params, module, prefix='', is_dcn_module=None):
+        """Add all parameters of module to the params list.
+
+        The parameters of the given module will be added to the list of param
+        groups, with specific rules defined by paramwise_cfg.
+
+        Args:
+            params (list[dict]): A list of param groups, it will be modified
+                in place.
+            module (nn.Module): The module to be added.
+            prefix (str): The prefix of the module
+            is_dcn_module (int|float|None): If the current module is a
+                submodule of DCN, `is_dcn_module` will be passed to
+                control conv_offset layer's learning rate. Defaults to None.
+        """
+        # get param-wise options
+        custom_keys = self.paramwise_cfg.get('custom_keys', {})
+        # sort keys alphabetically first, then by descending length, so the
+        # longest matching key wins and ties are broken alphabetically
+        sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
+
+        bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
+        bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
+        norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
+        dwconv_decay_mult = self.paramwise_cfg.get('dwconv_decay_mult', 1.)
+        bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)
+        dcn_offset_lr_mult = self.paramwise_cfg.get('dcn_offset_lr_mult', 1.)
+
+        # special rules for norm layers and depth-wise conv layers
+        is_norm = isinstance(module,
+                             (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))
+        is_dwconv = (
+            isinstance(module, torch.nn.Conv2d)
+            and module.in_channels == module.groups)
+
+        for name, param in module.named_parameters(recurse=False):
+            param_group = {'params': [param]}
+            if not param.requires_grad:
+                params.append(param_group)
+                continue
+            if bypass_duplicate and self._is_in(param_group, params):
+                warnings.warn(f'{prefix} is duplicate. It is skipped since '
+                              f'bypass_duplicate={bypass_duplicate}')
+                continue
+            # if the parameter matches one of the custom keys, ignore other rules
+            is_custom = False
+            for key in sorted_keys:
+                if key in f'{prefix}.{name}':
+                    is_custom = True
+                    lr_mult = custom_keys[key].get('lr_mult', 1.)
+                    param_group['lr'] = self.base_lr * lr_mult
+                    if self.base_wd is not None:
+                        decay_mult = custom_keys[key].get('decay_mult', 1.)
+                        param_group['weight_decay'] = self.base_wd * decay_mult
+                    break
+
+            if not is_custom:
+                # bias_lr_mult affects all bias parameters
+                # except for norm.bias and dcn.conv_offset.bias
+                if name == 'bias' and not (is_norm or is_dcn_module):
+                    param_group['lr'] = self.base_lr * bias_lr_mult
+
+                if (prefix.find('conv_offset') != -1 and is_dcn_module
+                        and isinstance(module, torch.nn.Conv2d)):
+                    # deal with both dcn_offset's bias & weight
+                    param_group['lr'] = self.base_lr * dcn_offset_lr_mult
+
+                # apply weight decay policies
+                if self.base_wd is not None:
+                    # norm decay
+                    if is_norm:
+                        param_group[
+                            'weight_decay'] = self.base_wd * norm_decay_mult
+                    # depth-wise conv
+                    elif is_dwconv:
+                        param_group[
+                            'weight_decay'] = self.base_wd * dwconv_decay_mult
+                    # bias lr and decay
+                    elif name == 'bias' and not is_dcn_module:
+                        # TODO: the current bias_decay_mult also has an effect on DCN
+                        param_group[
+                            'weight_decay'] = self.base_wd * bias_decay_mult
+            params.append(param_group)
+
+        if check_ops_exist():
+            from annotator.uniformer.mmcv.ops import DeformConv2d, ModulatedDeformConv2d
+            is_dcn_module = isinstance(module,
+                                       (DeformConv2d, ModulatedDeformConv2d))
+        else:
+            is_dcn_module = False
+        for child_name, child_mod in module.named_children():
+            child_prefix = f'{prefix}.{child_name}' if prefix else child_name
+            self.add_params(
+                params,
+                child_mod,
+                prefix=child_prefix,
+                is_dcn_module=is_dcn_module)
+
+    def __call__(self, model):
+        if hasattr(model, 'module'):
+            model = model.module
+
+        optimizer_cfg = self.optimizer_cfg.copy()
+        # if no paramwise option is specified, just use the global setting
+        if not self.paramwise_cfg:
+            optimizer_cfg['params'] = model.parameters()
+            return build_from_cfg(optimizer_cfg, OPTIMIZERS)
+
+        # set param-wise lr and weight decay recursively
+        params = []
+        self.add_params(params, model)
+        optimizer_cfg['params'] = params
+
+        return build_from_cfg(optimizer_cfg, OPTIMIZERS)
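A small illustration of the `custom_keys` matching rule described in the docstring above (the key and parameter names are hypothetical): keys are sorted so that the longest match wins, with alphabetical order only breaking length ties, and the first sorted key contained in the parameter's qualified name is the one applied.

```python
custom_keys = {
    'backbone': dict(lr_mult=0.1),
    'backbone.stem': dict(lr_mult=0.0),
}
# Same double sort as add_params: alphabetical first, then longest-first.
sorted_keys = sorted(sorted(custom_keys), key=len, reverse=True)
print(sorted_keys)  # ['backbone.stem', 'backbone']

param_name = 'backbone.stem.conv1.weight'  # stands in for f'{prefix}.{name}'
matched = next((k for k in sorted_keys if k in param_name), None)
print(matched)  # 'backbone.stem' -> lr_mult=0.0 applies, not 0.1
```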
diff --git a/annotator/uniformer/mmcv/runner/priority.py b/annotator/uniformer/mmcv/runner/priority.py
new file mode 100644
index 0000000000000000000000000000000000000000..64cc4e3a05f8d5b89ab6eb32461e6e80f1d62e67
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/priority.py
@@ -0,0 +1,60 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from enum import Enum
+
+
+class Priority(Enum):
+    """Hook priority levels.
+
+    +--------------+------------+
+    | Level        | Value      |
+    +==============+============+
+    | HIGHEST      | 0          |
+    +--------------+------------+
+    | VERY_HIGH    | 10         |
+    +--------------+------------+
+    | HIGH         | 30         |
+    +--------------+------------+
+    | ABOVE_NORMAL | 40         |
+    +--------------+------------+
+    | NORMAL       | 50         |
+    +--------------+------------+
+    | BELOW_NORMAL | 60         |
+    +--------------+------------+
+    | LOW          | 70         |
+    +--------------+------------+
+    | VERY_LOW     | 90         |
+    +--------------+------------+
+    | LOWEST       | 100        |
+    +--------------+------------+
+    """
+
+    HIGHEST = 0
+    VERY_HIGH = 10
+    HIGH = 30
+    ABOVE_NORMAL = 40
+    NORMAL = 50
+    BELOW_NORMAL = 60
+    LOW = 70
+    VERY_LOW = 90
+    LOWEST = 100
+
+
+def get_priority(priority):
+    """Get priority value.
+
+    Args:
+        priority (int or str or :obj:`Priority`): Priority.
+
+    Returns:
+        int: The priority value.
+    """
+    if isinstance(priority, int):
+        if priority < 0 or priority > 100:
+            raise ValueError('priority must be between 0 and 100')
+        return priority
+    elif isinstance(priority, Priority):
+        return priority.value
+    elif isinstance(priority, str):
+        return Priority[priority.upper()].value
+    else:
+        raise TypeError('priority must be an integer or Priority enum value')
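A quick check of the inputs `get_priority` accepts (assuming the vendored package imports cleanly): integers are range-checked and passed through, strings are looked up by enum name, and `Priority` members map to their values.

```python
from annotator.uniformer.mmcv.runner.priority import Priority, get_priority

print(get_priority(50))               # 50
print(get_priority('VERY_HIGH'))      # 10
print(get_priority(Priority.LOWEST))  # 100
```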
diff --git a/annotator/uniformer/mmcv/runner/utils.py b/annotator/uniformer/mmcv/runner/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5befb8e56ece50b5fecfd007b26f8a29124c0bd
--- /dev/null
+++ b/annotator/uniformer/mmcv/runner/utils.py
@@ -0,0 +1,93 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import random
+import sys
+import time
+import warnings
+from getpass import getuser
+from socket import gethostname
+
+import numpy as np
+import torch
+
+import annotator.uniformer.mmcv as mmcv
+
+
+def get_host_info():
+    """Get hostname and username.
+
+    Return an empty string if an exception is raised, e.g.
+    ``getpass.getuser()`` may fail inside a docker container.
+    """
+    host = ''
+    try:
+        host = f'{getuser()}@{gethostname()}'
+    except Exception as e:
+        warnings.warn(f'Host or user not found: {str(e)}')
+    finally:
+        return host
+
+
+def get_time_str():
+    return time.strftime('%Y%m%d_%H%M%S', time.localtime())
+
+
+def obj_from_dict(info, parent=None, default_args=None):
+    """Initialize an object from dict.
+
+    The dict must contain the key "type", which indicates the object type; it
+    can be either a string or a type, such as "list" or ``list``. The remaining
+    fields are treated as the arguments for constructing the object.
+
+    Args:
+        info (dict): Object types and arguments.
+        parent (:class:`module`): Module which may contain the expected object
+            classes.
+        default_args (dict, optional): Default arguments for initializing the
+            object.
+
+    Returns:
+        any type: Object built from the dict.
+    """
+    assert isinstance(info, dict) and 'type' in info
+    assert isinstance(default_args, dict) or default_args is None
+    args = info.copy()
+    obj_type = args.pop('type')
+    if mmcv.is_str(obj_type):
+        if parent is not None:
+            obj_type = getattr(parent, obj_type)
+        else:
+            obj_type = sys.modules[obj_type]
+    elif not isinstance(obj_type, type):
+        raise TypeError('type must be a str or valid type, but '
+                        f'got {type(obj_type)}')
+    if default_args is not None:
+        for name, value in default_args.items():
+            args.setdefault(name, value)
+    return obj_type(**args)
+
+
+def set_random_seed(seed, deterministic=False, use_rank_shift=False):
+    """Set random seed.
+
+    Args:
+        seed (int): Seed to be used.
+        deterministic (bool): Whether to set the deterministic option for
+            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+            to True and `torch.backends.cudnn.benchmark` to False.
+            Default: False.
+        use_rank_shift (bool): Whether to add the rank number to the random
+            seed so that different ranks use different seeds. Default: False.
+    """
+    if use_rank_shift:
+        rank, _ = mmcv.runner.get_dist_info()
+        seed += rank
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    if deterministic:
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
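A usage sketch of `obj_from_dict` (again assuming the vendored package and `torch` import cleanly): the `type` field is resolved against `parent`, the remaining keys become constructor kwargs, and `default_args` fills in anything missing.

```python
import torch

from annotator.uniformer.mmcv.runner.utils import obj_from_dict

model = torch.nn.Linear(4, 2)
optimizer = obj_from_dict(
    dict(type='SGD', lr=0.01, momentum=0.9),
    parent=torch.optim,
    default_args=dict(params=model.parameters()))
print(type(optimizer).__name__)  # SGD
```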
diff --git a/annotator/uniformer/mmcv/utils/__init__.py b/annotator/uniformer/mmcv/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..378a0068432a371af364de9d73785901c0f83383
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/__init__.py
@@ -0,0 +1,69 @@
+# flake8: noqa
+# Copyright (c) OpenMMLab. All rights reserved.
+from .config import Config, ConfigDict, DictAction
+from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
+                   has_method, import_modules_from_strings, is_list_of,
+                   is_method_overridden, is_seq_of, is_str, is_tuple_of,
+                   iter_cast, list_cast, requires_executable, requires_package,
+                   slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple,
+                   to_ntuple, tuple_cast)
+from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
+                   scandir, symlink)
+from .progressbar import (ProgressBar, track_iter_progress,
+                          track_parallel_progress, track_progress)
+from .testing import (assert_attrs_equal, assert_dict_contains_subset,
+                      assert_dict_has_keys, assert_is_norm_layer,
+                      assert_keys_equal, assert_params_all_zeros,
+                      check_python_script)
+from .timer import Timer, TimerError, check_time
+from .version_utils import digit_version, get_git_hash
+
+try:
+    import torch
+except ImportError:
+    __all__ = [
+        'Config', 'ConfigDict', 'DictAction', 'is_str', 'iter_cast',
+        'list_cast', 'tuple_cast', 'is_seq_of', 'is_list_of', 'is_tuple_of',
+        'slice_list', 'concat_list', 'check_prerequisites', 'requires_package',
+        'requires_executable', 'is_filepath', 'fopen', 'check_file_exist',
+        'mkdir_or_exist', 'symlink', 'scandir', 'ProgressBar',
+        'track_progress', 'track_iter_progress', 'track_parallel_progress',
+        'Timer', 'TimerError', 'check_time', 'deprecated_api_warning',
+        'digit_version', 'get_git_hash', 'import_modules_from_strings',
+        'assert_dict_contains_subset', 'assert_attrs_equal',
+        'assert_dict_has_keys', 'assert_keys_equal', 'check_python_script',
+        'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
+        'is_method_overridden', 'has_method'
+    ]
+else:
+    from .env import collect_env
+    from .logging import get_logger, print_log
+    from .parrots_jit import jit, skip_no_elena
+    from .parrots_wrapper import (
+        TORCH_VERSION, BuildExtension, CppExtension, CUDAExtension, DataLoader,
+        PoolDataLoader, SyncBatchNorm, _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd,
+        _AvgPoolNd, _BatchNorm, _ConvNd, _ConvTransposeMixin, _InstanceNorm,
+        _MaxPoolNd, get_build_config, is_rocm_pytorch, _get_cuda_home)
+    from .registry import Registry, build_from_cfg
+    from .trace import is_jit_tracing
+    __all__ = [
+        'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
+        'print_log', 'is_str', 'iter_cast', 'list_cast', 'tuple_cast',
+        'is_seq_of', 'is_list_of', 'is_tuple_of', 'slice_list', 'concat_list',
+        'check_prerequisites', 'requires_package', 'requires_executable',
+        'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist',
+        'symlink', 'scandir', 'ProgressBar', 'track_progress',
+        'track_iter_progress', 'track_parallel_progress', 'Registry',
+        'build_from_cfg', 'Timer', 'TimerError', 'check_time', 'SyncBatchNorm',
+        '_AdaptiveAvgPoolNd', '_AdaptiveMaxPoolNd', '_AvgPoolNd', '_BatchNorm',
+        '_ConvNd', '_ConvTransposeMixin', '_InstanceNorm', '_MaxPoolNd',
+        'get_build_config', 'BuildExtension', 'CppExtension', 'CUDAExtension',
+        'DataLoader', 'PoolDataLoader', 'TORCH_VERSION',
+        'deprecated_api_warning', 'digit_version', 'get_git_hash',
+        'import_modules_from_strings', 'jit', 'skip_no_elena',
+        'assert_dict_contains_subset', 'assert_attrs_equal',
+        'assert_dict_has_keys', 'assert_keys_equal', 'assert_is_norm_layer',
+        'assert_params_all_zeros', 'check_python_script',
+        'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
+        '_get_cuda_home', 'has_method'
+    ]
diff --git a/annotator/uniformer/mmcv/utils/config.py b/annotator/uniformer/mmcv/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..17149353aefac6d737c67bb2f35a3a6cd2147b0a
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/config.py
@@ -0,0 +1,688 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import ast
+import copy
+import os
+import os.path as osp
+import platform
+import shutil
+import sys
+import tempfile
+import uuid
+import warnings
+from argparse import Action, ArgumentParser
+from collections import abc
+from importlib import import_module
+
+from addict import Dict
+from yapf.yapflib.yapf_api import FormatCode
+
+from .misc import import_modules_from_strings
+from .path import check_file_exist
+
+if platform.system() == 'Windows':
+    import regex as re
+else:
+    import re
+
+BASE_KEY = '_base_'
+DELETE_KEY = '_delete_'
+DEPRECATION_KEY = '_deprecation_'
+RESERVED_KEYS = ['filename', 'text', 'pretty_text']
+
+
+class ConfigDict(Dict):
+
+    def __missing__(self, name):
+        raise KeyError(name)
+
+    def __getattr__(self, name):
+        try:
+            value = super(ConfigDict, self).__getattr__(name)
+        except KeyError:
+            ex = AttributeError(f"'{self.__class__.__name__}' object has no "
+                                f"attribute '{name}'")
+        except Exception as e:
+            ex = e
+        else:
+            return value
+        raise ex
+
+
+def add_args(parser, cfg, prefix=''):
+    for k, v in cfg.items():
+        if isinstance(v, str):
+            parser.add_argument('--' + prefix + k)
+        elif isinstance(v, int):
+            parser.add_argument('--' + prefix + k, type=int)
+        elif isinstance(v, float):
+            parser.add_argument('--' + prefix + k, type=float)
+        elif isinstance(v, bool):
+            parser.add_argument('--' + prefix + k, action='store_true')
+        elif isinstance(v, dict):
+            add_args(parser, v, prefix + k + '.')
+        elif isinstance(v, abc.Iterable):
+            parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+')
+        else:
+            print(f'cannot parse key {prefix + k} of type {type(v)}')
+    return parser
+
+
+class Config:
+    """A facility for config and config files.
+
+    It supports common file formats as configs: python/json/yaml. The interface
+    is the same as a dict object and also allows accessing config values as
+    attributes.
+
+    Example:
+        >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
+        >>> cfg.a
+        1
+        >>> cfg.b
+        {'b1': [0, 1]}
+        >>> cfg.b.b1
+        [0, 1]
+        >>> cfg = Config.fromfile('tests/data/config/a.py')
+        >>> cfg.filename
+        "/home/kchen/projects/mmcv/tests/data/config/a.py"
+        >>> cfg.item4
+        'test'
+        >>> cfg
+        "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
+        "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
+    """
+
+    @staticmethod
+    def _validate_py_syntax(filename):
+        with open(filename, 'r', encoding='utf-8') as f:
+            # Setting encoding explicitly to resolve coding issue on windows
+            content = f.read()
+        try:
+            ast.parse(content)
+        except SyntaxError as e:
+            raise SyntaxError('There are syntax errors in config '
+                              f'file {filename}: {e}')
+
+    @staticmethod
+    def _substitute_predefined_vars(filename, temp_config_name):
+        file_dirname = osp.dirname(filename)
+        file_basename = osp.basename(filename)
+        file_basename_no_extension = osp.splitext(file_basename)[0]
+        file_extname = osp.splitext(filename)[1]
+        support_templates = dict(
+            fileDirname=file_dirname,
+            fileBasename=file_basename,
+            fileBasenameNoExtension=file_basename_no_extension,
+            fileExtname=file_extname)
+        with open(filename, 'r', encoding='utf-8') as f:
+            # Setting encoding explicitly to resolve coding issue on windows
+            config_file = f.read()
+        for key, value in support_templates.items():
+            regexp = r'\{\{\s*' + str(key) + r'\s*\}\}'
+            value = value.replace('\\', '/')
+            config_file = re.sub(regexp, value, config_file)
+        with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file:
+            tmp_config_file.write(config_file)
+
+    @staticmethod
+    def _pre_substitute_base_vars(filename, temp_config_name):
+        """Substitute base variable placehoders to string, so that parsing
+        would work."""
+        with open(filename, 'r', encoding='utf-8') as f:
+            # Setting encoding explicitly to resolve coding issue on windows
+            config_file = f.read()
+        base_var_dict = {}
+        regexp = r'\{\{\s*' + BASE_KEY + r'\.([\w\.]+)\s*\}\}'
+        base_vars = set(re.findall(regexp, config_file))
+        for base_var in base_vars:
+            randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}'
+            base_var_dict[randstr] = base_var
+            regexp = r'\{\{\s*' + BASE_KEY + r'\.' + base_var + r'\s*\}\}'
+            config_file = re.sub(regexp, f'"{randstr}"', config_file)
+        with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file:
+            tmp_config_file.write(config_file)
+        return base_var_dict
+
+    @staticmethod
+    def _substitute_base_vars(cfg, base_var_dict, base_cfg):
+        """Substitute variable strings to their actual values."""
+        cfg = copy.deepcopy(cfg)
+
+        if isinstance(cfg, dict):
+            for k, v in cfg.items():
+                if isinstance(v, str) and v in base_var_dict:
+                    new_v = base_cfg
+                    for new_k in base_var_dict[v].split('.'):
+                        new_v = new_v[new_k]
+                    cfg[k] = new_v
+                elif isinstance(v, (list, tuple, dict)):
+                    cfg[k] = Config._substitute_base_vars(
+                        v, base_var_dict, base_cfg)
+        elif isinstance(cfg, tuple):
+            cfg = tuple(
+                Config._substitute_base_vars(c, base_var_dict, base_cfg)
+                for c in cfg)
+        elif isinstance(cfg, list):
+            cfg = [
+                Config._substitute_base_vars(c, base_var_dict, base_cfg)
+                for c in cfg
+            ]
+        elif isinstance(cfg, str) and cfg in base_var_dict:
+            new_v = base_cfg
+            for new_k in base_var_dict[cfg].split('.'):
+                new_v = new_v[new_k]
+            cfg = new_v
+
+        return cfg
+
+    @staticmethod
+    def _file2dict(filename, use_predefined_variables=True):
+        filename = osp.abspath(osp.expanduser(filename))
+        check_file_exist(filename)
+        fileExtname = osp.splitext(filename)[1]
+        if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
+            raise IOError('Only py/yml/yaml/json type are supported now!')
+
+        with tempfile.TemporaryDirectory() as temp_config_dir:
+            temp_config_file = tempfile.NamedTemporaryFile(
+                dir=temp_config_dir, suffix=fileExtname)
+            if platform.system() == 'Windows':
+                temp_config_file.close()
+            temp_config_name = osp.basename(temp_config_file.name)
+            # Substitute predefined variables
+            if use_predefined_variables:
+                Config._substitute_predefined_vars(filename,
+                                                   temp_config_file.name)
+            else:
+                shutil.copyfile(filename, temp_config_file.name)
+            # Substitute base variables from placeholders to strings
+            base_var_dict = Config._pre_substitute_base_vars(
+                temp_config_file.name, temp_config_file.name)
+
+            if filename.endswith('.py'):
+                temp_module_name = osp.splitext(temp_config_name)[0]
+                sys.path.insert(0, temp_config_dir)
+                Config._validate_py_syntax(filename)
+                mod = import_module(temp_module_name)
+                sys.path.pop(0)
+                cfg_dict = {
+                    name: value
+                    for name, value in mod.__dict__.items()
+                    if not name.startswith('__')
+                }
+                # delete imported module
+                del sys.modules[temp_module_name]
+            elif filename.endswith(('.yml', '.yaml', '.json')):
+                import annotator.uniformer.mmcv as mmcv
+                cfg_dict = mmcv.load(temp_config_file.name)
+            # close temp file
+            temp_config_file.close()
+
+        # check deprecation information
+        if DEPRECATION_KEY in cfg_dict:
+            deprecation_info = cfg_dict.pop(DEPRECATION_KEY)
+            warning_msg = f'The config file {filename} will be deprecated ' \
+                'in the future.'
+            if 'expected' in deprecation_info:
+                warning_msg += f' Please use {deprecation_info["expected"]} ' \
+                    'instead.'
+            if 'reference' in deprecation_info:
+                warning_msg += ' More information can be found at ' \
+                    f'{deprecation_info["reference"]}'
+            warnings.warn(warning_msg)
+
+        cfg_text = filename + '\n'
+        with open(filename, 'r', encoding='utf-8') as f:
+            # Setting encoding explicitly to resolve coding issue on windows
+            cfg_text += f.read()
+
+        if BASE_KEY in cfg_dict:
+            cfg_dir = osp.dirname(filename)
+            base_filename = cfg_dict.pop(BASE_KEY)
+            base_filename = base_filename if isinstance(
+                base_filename, list) else [base_filename]
+
+            cfg_dict_list = list()
+            cfg_text_list = list()
+            for f in base_filename:
+                _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
+                cfg_dict_list.append(_cfg_dict)
+                cfg_text_list.append(_cfg_text)
+
+            base_cfg_dict = dict()
+            for c in cfg_dict_list:
+                duplicate_keys = base_cfg_dict.keys() & c.keys()
+                if len(duplicate_keys) > 0:
+                    raise KeyError('Duplicate key is not allowed among bases. '
+                                   f'Duplicate keys: {duplicate_keys}')
+                base_cfg_dict.update(c)
+
+            # Substitute base variables from strings to their actual values
+            cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict,
+                                                    base_cfg_dict)
+
+            base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
+            cfg_dict = base_cfg_dict
+
+            # merge cfg_text
+            cfg_text_list.append(cfg_text)
+            cfg_text = '\n'.join(cfg_text_list)
+
+        return cfg_dict, cfg_text
+
+    @staticmethod
+    def _merge_a_into_b(a, b, allow_list_keys=False):
+        """merge dict ``a`` into dict ``b`` (non-inplace).
+
+        Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
+        in-place modifications.
+
+        Args:
+            a (dict): The source dict to be merged into ``b``.
+            b (dict): The base dict that receives keys from ``a``; it is
+              copied rather than modified in place.
+            allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
+              are allowed in source ``a`` and will replace the element of the
+              corresponding index in ``b`` if ``b`` is a list. Default: False.
+
+        Returns:
+            dict: The modified copy of ``b`` merged with ``a``.
+
+        Examples:
+            # Normally merge a into b.
+            >>> Config._merge_a_into_b(
+            ...     dict(obj=dict(a=2)), dict(obj=dict(a=1)))
+            {'obj': {'a': 2}}
+
+            # Delete b first and merge a into b.
+            >>> Config._merge_a_into_b(
+            ...     dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1)))
+            {'obj': {'a': 2}}
+
+            # b is a list
+            >>> Config._merge_a_into_b(
+            ...     {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True)
+            [{'a': 2}, {'b': 2}]
+        """
+        b = b.copy()
+        for k, v in a.items():
+            if allow_list_keys and k.isdigit() and isinstance(b, list):
+                k = int(k)
+                if len(b) <= k:
+                    raise KeyError(f'Index {k} exceeds the length of list {b}')
+                b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
+            elif isinstance(v,
+                            dict) and k in b and not v.pop(DELETE_KEY, False):
+                allowed_types = (dict, list) if allow_list_keys else dict
+                if not isinstance(b[k], allowed_types):
+                    raise TypeError(
+                        f'{k}={v} in child config cannot inherit from base '
+                        f'because {k} is a dict in the child config but is of '
+                        f'type {type(b[k])} in base config. You may set '
+                        f'`{DELETE_KEY}=True` to ignore the base config')
+                b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys)
+            else:
+                b[k] = v
+        return b
+
+    @staticmethod
+    def fromfile(filename,
+                 use_predefined_variables=True,
+                 import_custom_modules=True):
+        cfg_dict, cfg_text = Config._file2dict(filename,
+                                               use_predefined_variables)
+        if import_custom_modules and cfg_dict.get('custom_imports', None):
+            import_modules_from_strings(**cfg_dict['custom_imports'])
+        return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
+
+    @staticmethod
+    def fromstring(cfg_str, file_format):
+        """Generate config from config str.
+
+        Args:
+            cfg_str (str): Config str.
+            file_format (str): Config file format corresponding to the
+               config str. Only py/yml/yaml/json types are supported now!
+
+        Returns:
+            obj:`Config`: Config obj.
+        """
+        if file_format not in ['.py', '.json', '.yaml', '.yml']:
+            raise IOError('Only py/yml/yaml/json type are supported now!')
+        if file_format != '.py' and 'dict(' in cfg_str:
+            # check if users specify a wrong suffix for python
+            warnings.warn(
+                'Please check "file_format", the file format may be .py')
+        with tempfile.NamedTemporaryFile(
+                'w', encoding='utf-8', suffix=file_format,
+                delete=False) as temp_file:
+            temp_file.write(cfg_str)
+            # on windows, previous implementation cause error
+            # see PR 1077 for details
+        cfg = Config.fromfile(temp_file.name)
+        os.remove(temp_file.name)
+        return cfg
+
+    @staticmethod
+    def auto_argparser(description=None):
+        """Generate argparser from config file automatically (experimental)"""
+        partial_parser = ArgumentParser(description=description)
+        partial_parser.add_argument('config', help='config file path')
+        cfg_file = partial_parser.parse_known_args()[0].config
+        cfg = Config.fromfile(cfg_file)
+        parser = ArgumentParser(description=description)
+        parser.add_argument('config', help='config file path')
+        add_args(parser, cfg)
+        return parser, cfg
+
+    def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
+        if cfg_dict is None:
+            cfg_dict = dict()
+        elif not isinstance(cfg_dict, dict):
+            raise TypeError('cfg_dict must be a dict, but '
+                            f'got {type(cfg_dict)}')
+        for key in cfg_dict:
+            if key in RESERVED_KEYS:
+                raise KeyError(f'{key} is reserved for config file')
+
+        super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
+        super(Config, self).__setattr__('_filename', filename)
+        if cfg_text:
+            text = cfg_text
+        elif filename:
+            with open(filename, 'r') as f:
+                text = f.read()
+        else:
+            text = ''
+        super(Config, self).__setattr__('_text', text)
+
+    @property
+    def filename(self):
+        return self._filename
+
+    @property
+    def text(self):
+        return self._text
+
+    @property
+    def pretty_text(self):
+
+        indent = 4
+
+        def _indent(s_, num_spaces):
+            s = s_.split('\n')
+            if len(s) == 1:
+                return s_
+            first = s.pop(0)
+            s = [(num_spaces * ' ') + line for line in s]
+            s = '\n'.join(s)
+            s = first + '\n' + s
+            return s
+
+        def _format_basic_types(k, v, use_mapping=False):
+            if isinstance(v, str):
+                v_str = f"'{v}'"
+            else:
+                v_str = str(v)
+
+            if use_mapping:
+                k_str = f"'{k}'" if isinstance(k, str) else str(k)
+                attr_str = f'{k_str}: {v_str}'
+            else:
+                attr_str = f'{str(k)}={v_str}'
+            attr_str = _indent(attr_str, indent)
+
+            return attr_str
+
+        def _format_list(k, v, use_mapping=False):
+            # check if all items in the list are dict
+            if all(isinstance(_, dict) for _ in v):
+                v_str = '[\n'
+                v_str += '\n'.join(
+                    f'dict({_indent(_format_dict(v_), indent)}),'
+                    for v_ in v).rstrip(',')
+                if use_mapping:
+                    k_str = f"'{k}'" if isinstance(k, str) else str(k)
+                    attr_str = f'{k_str}: {v_str}'
+                else:
+                    attr_str = f'{str(k)}={v_str}'
+                attr_str = _indent(attr_str, indent) + ']'
+            else:
+                attr_str = _format_basic_types(k, v, use_mapping)
+            return attr_str
+
+        def _contain_invalid_identifier(dict_str):
+            contain_invalid_identifier = False
+            for key_name in dict_str:
+                contain_invalid_identifier |= \
+                    (not str(key_name).isidentifier())
+            return contain_invalid_identifier
+
+        def _format_dict(input_dict, outest_level=False):
+            r = ''
+            s = []
+
+            use_mapping = _contain_invalid_identifier(input_dict)
+            if use_mapping:
+                r += '{'
+            for idx, (k, v) in enumerate(input_dict.items()):
+                is_last = idx >= len(input_dict) - 1
+                end = '' if outest_level or is_last else ','
+                if isinstance(v, dict):
+                    v_str = '\n' + _format_dict(v)
+                    if use_mapping:
+                        k_str = f"'{k}'" if isinstance(k, str) else str(k)
+                        attr_str = f'{k_str}: dict({v_str}'
+                    else:
+                        attr_str = f'{str(k)}=dict({v_str}'
+                    attr_str = _indent(attr_str, indent) + ')' + end
+                elif isinstance(v, list):
+                    attr_str = _format_list(k, v, use_mapping) + end
+                else:
+                    attr_str = _format_basic_types(k, v, use_mapping) + end
+
+                s.append(attr_str)
+            r += '\n'.join(s)
+            if use_mapping:
+                r += '}'
+            return r
+
+        cfg_dict = self._cfg_dict.to_dict()
+        text = _format_dict(cfg_dict, outest_level=True)
+        # copied from setup.cfg
+        yapf_style = dict(
+            based_on_style='pep8',
+            blank_line_before_nested_class_or_def=True,
+            split_before_expression_after_opening_paren=True)
+        text, _ = FormatCode(text, style_config=yapf_style, verify=True)
+
+        return text
+
+    def __repr__(self):
+        return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'
+
+    def __len__(self):
+        return len(self._cfg_dict)
+
+    def __getattr__(self, name):
+        return getattr(self._cfg_dict, name)
+
+    def __getitem__(self, name):
+        return self._cfg_dict.__getitem__(name)
+
+    def __setattr__(self, name, value):
+        if isinstance(value, dict):
+            value = ConfigDict(value)
+        self._cfg_dict.__setattr__(name, value)
+
+    def __setitem__(self, name, value):
+        if isinstance(value, dict):
+            value = ConfigDict(value)
+        self._cfg_dict.__setitem__(name, value)
+
+    def __iter__(self):
+        return iter(self._cfg_dict)
+
+    def __getstate__(self):
+        return (self._cfg_dict, self._filename, self._text)
+
+    def __setstate__(self, state):
+        _cfg_dict, _filename, _text = state
+        super(Config, self).__setattr__('_cfg_dict', _cfg_dict)
+        super(Config, self).__setattr__('_filename', _filename)
+        super(Config, self).__setattr__('_text', _text)
+
+    def dump(self, file=None):
+        cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
+        if self.filename.endswith('.py'):
+            if file is None:
+                return self.pretty_text
+            else:
+                with open(file, 'w', encoding='utf-8') as f:
+                    f.write(self.pretty_text)
+        else:
+            import annotator.uniformer.mmcv as mmcv
+            if file is None:
+                file_format = self.filename.split('.')[-1]
+                return mmcv.dump(cfg_dict, file_format=file_format)
+            else:
+                mmcv.dump(cfg_dict, file)
+
+    def merge_from_dict(self, options, allow_list_keys=True):
+        """Merge list into cfg_dict.
+
+        Merge the dict parsed by MultipleKVAction into this cfg.
+
+        Examples:
+            >>> options = {'model.backbone.depth': 50,
+            ...            'model.backbone.with_cp':True}
+            >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
+            >>> cfg.merge_from_dict(options)
+            >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
+            >>> assert cfg_dict == dict(
+            ...     model=dict(backbone=dict(depth=50, with_cp=True)))
+
+            # Merge list element
+            >>> cfg = Config(dict(pipeline=[
+            ...     dict(type='LoadImage'), dict(type='LoadAnnotations')]))
+            >>> options = dict(pipeline={'0': dict(type='SelfLoadImage')})
+            >>> cfg.merge_from_dict(options, allow_list_keys=True)
+            >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
+            >>> assert cfg_dict == dict(pipeline=[
+            ...     dict(type='SelfLoadImage'), dict(type='LoadAnnotations')])
+
+        Args:
+            options (dict): dict of configs to merge from.
+            allow_list_keys (bool): If True, int string keys (e.g. '0', '1')
+              are allowed in ``options`` and will replace the element of the
+              corresponding index in the config if the config is a list.
+              Default: True.
+        """
+        option_cfg_dict = {}
+        for full_key, v in options.items():
+            d = option_cfg_dict
+            key_list = full_key.split('.')
+            for subkey in key_list[:-1]:
+                d.setdefault(subkey, ConfigDict())
+                d = d[subkey]
+            subkey = key_list[-1]
+            d[subkey] = v
+
+        cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
+        super(Config, self).__setattr__(
+            '_cfg_dict',
+            Config._merge_a_into_b(
+                option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys))
+
+
+class DictAction(Action):
+    """
+    argparse action to split an argument into KEY=VALUE form
+    on the first = and append to a dictionary. List options can
+    be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit
+    brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build
+    list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
+    """
+
+    @staticmethod
+    def _parse_int_float_bool(val):
+        try:
+            return int(val)
+        except ValueError:
+            pass
+        try:
+            return float(val)
+        except ValueError:
+            pass
+        if val.lower() in ('true', 'false'):
+            return val.lower() == 'true'
+        return val
+
+    @staticmethod
+    def _parse_iterable(val):
+        """Parse iterable values in the string.
+
+        All elements inside '()' or '[]' are treated as iterable values.
+
+        Args:
+            val (str): Value string.
+
+        Returns:
+            list | tuple: The expanded list or tuple from the string.
+
+        Examples:
+            >>> DictAction._parse_iterable('1,2,3')
+            [1, 2, 3]
+            >>> DictAction._parse_iterable('[a, b, c]')
+            ['a', 'b', 'c']
+            >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]')
+            [(1, 2, 3), ['a', 'b'], 'c']
+        """
+
+        def find_next_comma(string):
+            """Find the position of next comma in the string.
+
+            If no ',' is found in the string, return the string length. All
+            chars inside '()' and '[]' are treated as one element, so commas
+            inside these brackets are ignored.
+            """
+            assert (string.count('(') == string.count(')')) and (
+                    string.count('[') == string.count(']')), \
+                f'Imbalanced brackets exist in {string}'
+            end = len(string)
+            for idx, char in enumerate(string):
+                pre = string[:idx]
+                # The string before this ',' is balanced
+                if ((char == ',') and (pre.count('(') == pre.count(')'))
+                        and (pre.count('[') == pre.count(']'))):
+                    end = idx
+                    break
+            return end
+
+        # Strip ' and " characters and replace whitespace.
+        val = val.strip('\'\"').replace(' ', '')
+        is_tuple = False
+        if val.startswith('(') and val.endswith(')'):
+            is_tuple = True
+            val = val[1:-1]
+        elif val.startswith('[') and val.endswith(']'):
+            val = val[1:-1]
+        elif ',' not in val:
+            # val is a single value
+            return DictAction._parse_int_float_bool(val)
+
+        values = []
+        while len(val) > 0:
+            comma_idx = find_next_comma(val)
+            element = DictAction._parse_iterable(val[:comma_idx])
+            values.append(element)
+            val = val[comma_idx + 1:]
+        if is_tuple:
+            values = tuple(values)
+        return values
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        options = {}
+        for kv in values:
+            key, val = kv.split('=', maxsplit=1)
+            options[key] = self._parse_iterable(val)
+        setattr(namespace, self.dest, options)
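A minimal usage sketch for the Config/DictAction pair above (not part of the patch; it assumes the vendored import path annotator.uniformer.mmcv.utils.config and that the addict/yapf dependencies are installed):

import argparse

from annotator.uniformer.mmcv.utils.config import Config, DictAction

# Collect dotted-key overrides from the command line via DictAction.
parser = argparse.ArgumentParser()
parser.add_argument('--cfg-options', nargs='+', action=DictAction, default={})
args = parser.parse_args(
    ['--cfg-options', 'model.backbone.depth=101', 'train.lr=[0.01,0.001]'])

# Merge the overrides into a config built from a plain dict.
cfg = Config(dict(model=dict(backbone=dict(type='ResNet', depth=50)),
                  train=dict(lr=0.02)))
cfg.merge_from_dict(args.cfg_options)
assert cfg.model.backbone.depth == 101 and cfg.train.lr == [0.01, 0.001]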
diff --git a/annotator/uniformer/mmcv/utils/env.py b/annotator/uniformer/mmcv/utils/env.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3f0d92529e193e6d8339419bcd9bed7901a7769
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/env.py
@@ -0,0 +1,95 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""This file holding some environment constant for sharing by other files."""
+
+import os.path as osp
+import subprocess
+import sys
+from collections import defaultdict
+
+import cv2
+import torch
+
+import annotator.uniformer.mmcv as mmcv
+from .parrots_wrapper import get_build_config
+
+
+def collect_env():
+    """Collect the information of the running environments.
+
+    Returns:
+        dict: The environment information. The following fields are contained.
+
+            - sys.platform: The variable of ``sys.platform``.
+            - Python: Python version.
+            - CUDA available: Bool, indicating if CUDA is available.
+            - GPU devices: Device type of each GPU.
+            - CUDA_HOME (optional): The env var ``CUDA_HOME``.
+            - NVCC (optional): NVCC version.
+            - GCC: GCC version, "n/a" if GCC is not installed.
+            - PyTorch: PyTorch version.
+            - PyTorch compiling details: The output of \
+                ``torch.__config__.show()``.
+            - TorchVision (optional): TorchVision version.
+            - OpenCV: OpenCV version.
+            - MMCV: MMCV version.
+            - MMCV Compiler: The GCC version for compiling MMCV ops.
+            - MMCV CUDA Compiler: The CUDA version for compiling MMCV ops.
+    """
+    env_info = {}
+    env_info['sys.platform'] = sys.platform
+    env_info['Python'] = sys.version.replace('\n', '')
+
+    cuda_available = torch.cuda.is_available()
+    env_info['CUDA available'] = cuda_available
+
+    if cuda_available:
+        devices = defaultdict(list)
+        for k in range(torch.cuda.device_count()):
+            devices[torch.cuda.get_device_name(k)].append(str(k))
+        for name, device_ids in devices.items():
+            env_info['GPU ' + ','.join(device_ids)] = name
+
+        from annotator.uniformer.mmcv.utils.parrots_wrapper import _get_cuda_home
+        CUDA_HOME = _get_cuda_home()
+        env_info['CUDA_HOME'] = CUDA_HOME
+
+        if CUDA_HOME is not None and osp.isdir(CUDA_HOME):
+            try:
+                nvcc = osp.join(CUDA_HOME, 'bin/nvcc')
+                nvcc = subprocess.check_output(
+                    f'"{nvcc}" -V | tail -n1', shell=True)
+                nvcc = nvcc.decode('utf-8').strip()
+            except subprocess.SubprocessError:
+                nvcc = 'Not Available'
+            env_info['NVCC'] = nvcc
+
+    try:
+        gcc = subprocess.check_output('gcc --version | head -n1', shell=True)
+        gcc = gcc.decode('utf-8').strip()
+        env_info['GCC'] = gcc
+    except subprocess.CalledProcessError:  # gcc is unavailable
+        env_info['GCC'] = 'n/a'
+
+    env_info['PyTorch'] = torch.__version__
+    env_info['PyTorch compiling details'] = get_build_config()
+
+    try:
+        import torchvision
+        env_info['TorchVision'] = torchvision.__version__
+    except ModuleNotFoundError:
+        pass
+
+    env_info['OpenCV'] = cv2.__version__
+
+    env_info['MMCV'] = mmcv.__version__
+
+    try:
+        from annotator.uniformer.mmcv.ops import get_compiler_version, get_compiling_cuda_version
+    except ModuleNotFoundError:
+        env_info['MMCV Compiler'] = 'n/a'
+        env_info['MMCV CUDA Compiler'] = 'n/a'
+    else:
+        env_info['MMCV Compiler'] = get_compiler_version()
+        env_info['MMCV CUDA Compiler'] = get_compiling_cuda_version()
+
+    return env_info
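A short sketch of how collect_env is typically consumed (assuming the vendored package and its torch/cv2 dependencies import cleanly):

from annotator.uniformer.mmcv.utils.env import collect_env

# Print the environment report one key per line, e.g. for bug reports or logs.
for key, value in collect_env().items():
    print(f'{key}: {value}')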
diff --git a/annotator/uniformer/mmcv/utils/ext_loader.py b/annotator/uniformer/mmcv/utils/ext_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..08132d2c1b9a1c28880e4bab4d4fa1ba39d9d083
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/ext_loader.py
@@ -0,0 +1,71 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import importlib
+import os
+import pkgutil
+import warnings
+from collections import namedtuple
+
+import torch
+
+if torch.__version__ != 'parrots':
+
+    def load_ext(name, funcs):
+        ext = importlib.import_module('mmcv.' + name)
+        for fun in funcs:
+            assert hasattr(ext, fun), f'{fun} is missing from module {name}'
+        return ext
+else:
+    from parrots import extension
+    from parrots.base import ParrotsException
+
+    has_return_value_ops = [
+        'nms',
+        'softnms',
+        'nms_match',
+        'nms_rotated',
+        'top_pool_forward',
+        'top_pool_backward',
+        'bottom_pool_forward',
+        'bottom_pool_backward',
+        'left_pool_forward',
+        'left_pool_backward',
+        'right_pool_forward',
+        'right_pool_backward',
+        'fused_bias_leakyrelu',
+        'upfirdn2d',
+        'ms_deform_attn_forward',
+        'pixel_group',
+        'contour_expand',
+    ]
+
+    def get_fake_func(name, e):
+
+        def fake_func(*args, **kwargs):
+            warnings.warn(f'{name} is not supported in parrots now')
+            raise e
+
+        return fake_func
+
+    def load_ext(name, funcs):
+        ExtModule = namedtuple('ExtModule', funcs)
+        ext_list = []
+        lib_root = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+        for fun in funcs:
+            try:
+                ext_fun = extension.load(fun, name, lib_dir=lib_root)
+            except ParrotsException as e:
+                if 'No element registered' not in e.message:
+                    warnings.warn(e.message)
+                ext_fun = get_fake_func(fun, e)
+                ext_list.append(ext_fun)
+            else:
+                if fun in has_return_value_ops:
+                    ext_list.append(ext_fun.op)
+                else:
+                    ext_list.append(ext_fun.op_)
+        return ExtModule(*ext_list)
+
+
+def check_ops_exist():
+    ext_loader = pkgutil.find_loader('mmcv._ext')
+    return ext_loader is not None
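A sketch of how the loader is commonly consumed; note that load_ext resolves ops from the real mmcv package ('mmcv.' + name), not from this vendored copy, so the example assumes mmcv itself is importable:

from annotator.uniformer.mmcv.utils.ext_loader import check_ops_exist, load_ext

if check_ops_exist():
    # 'nms' is one of the ops shipped with mmcv's compiled `_ext` module.
    ext_module = load_ext('_ext', ['nms'])
    print('compiled ops available:', ext_module)
else:
    print('mmcv._ext not found; compiled ops are unavailable')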
diff --git a/annotator/uniformer/mmcv/utils/logging.py b/annotator/uniformer/mmcv/utils/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..4aa0e04bb9b3ab2a4bfbc4def50404ccbac2c6e6
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/logging.py
@@ -0,0 +1,110 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import logging
+
+import torch.distributed as dist
+
+logger_initialized = {}
+
+
+def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
+    """Initialize and get a logger by name.
+
+    If the logger has not been initialized, this method will initialize the
+    logger by adding one or two handlers, otherwise the initialized logger will
+    be directly returned. During initialization, a StreamHandler will always be
+    added. If `log_file` is specified and the process rank is 0, a FileHandler
+    will also be added.
+
+    Args:
+        name (str): Logger name.
+        log_file (str | None): The log filename. If specified, a FileHandler
+            will be added to the logger.
+        log_level (int): The logger level. Note that only the process of
+            rank 0 is affected; other processes set the level to
+            "ERROR" and are thus silent most of the time.
+        file_mode (str): The file mode used in opening log file.
+            Defaults to 'w'.
+
+    Returns:
+        logging.Logger: The expected logger.
+    """
+    logger = logging.getLogger(name)
+    if name in logger_initialized:
+        return logger
+    # handle hierarchical names
+    # e.g., logger "a" is initialized, then logger "a.b" will skip the
+    # initialization since it is a child of "a".
+    for logger_name in logger_initialized:
+        if name.startswith(logger_name):
+            return logger
+
+    # handle duplicate logs to the console
+    # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler <stderr> (NOTSET)
+    # to the root logger. As logger.propagate is True by default, this root
+    # level handler causes logging messages from rank>0 processes to
+    # unexpectedly show up on the console, creating much unwanted clutter.
+    # To fix this issue, we set the root logger's StreamHandler, if any, to log
+    # at the ERROR level.
+    for handler in logger.root.handlers:
+        if type(handler) is logging.StreamHandler:
+            handler.setLevel(logging.ERROR)
+
+    stream_handler = logging.StreamHandler()
+    handlers = [stream_handler]
+
+    if dist.is_available() and dist.is_initialized():
+        rank = dist.get_rank()
+    else:
+        rank = 0
+
+    # only rank 0 will add a FileHandler
+    if rank == 0 and log_file is not None:
+        # The default mode of logging.FileHandler is 'a' (append). We expose
+        # the `file_mode` argument (defaulting to 'w' here) so callers can
+        # choose the behaviour they want.
+        file_handler = logging.FileHandler(log_file, file_mode)
+        handlers.append(file_handler)
+
+    formatter = logging.Formatter(
+        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    for handler in handlers:
+        handler.setFormatter(formatter)
+        handler.setLevel(log_level)
+        logger.addHandler(handler)
+
+    if rank == 0:
+        logger.setLevel(log_level)
+    else:
+        logger.setLevel(logging.ERROR)
+
+    logger_initialized[name] = True
+
+    return logger
+
+
+def print_log(msg, logger=None, level=logging.INFO):
+    """Print a log message.
+
+    Args:
+        msg (str): The message to be logged.
+        logger (logging.Logger | str | None): The logger to be used.
+            Some special loggers are:
+            - "silent": no message will be printed.
+            - other str: the logger obtained with `get_logger(logger)`.
+            - None: The `print()` method will be used to print log messages.
+        level (int): Logging level. Only available when `logger` is a Logger
+            object or "root".
+    """
+    if logger is None:
+        print(msg)
+    elif isinstance(logger, logging.Logger):
+        logger.log(level, msg)
+    elif logger == 'silent':
+        pass
+    elif isinstance(logger, str):
+        _logger = get_logger(logger)
+        _logger.log(level, msg)
+    else:
+        raise TypeError(
+            'logger should be either a logging.Logger object, str, '
+            f'"silent" or None, but got {type(logger)}')
diff --git a/annotator/uniformer/mmcv/utils/misc.py b/annotator/uniformer/mmcv/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c58d0d7fee9fe3d4519270ad8c1e998d0d8a18c
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/misc.py
@@ -0,0 +1,377 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import collections.abc
+import functools
+import itertools
+import subprocess
+import warnings
+from collections import abc
+from importlib import import_module
+from inspect import getfullargspec
+from itertools import repeat
+
+
+# From PyTorch internals
+def _ntuple(n):
+
+    def parse(x):
+        if isinstance(x, collections.abc.Iterable):
+            return x
+        return tuple(repeat(x, n))
+
+    return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
+
+
+def is_str(x):
+    """Whether the input is an string instance.
+
+    Note: This method is deprecated since python 2 is no longer supported.
+    """
+    return isinstance(x, str)
+
+
+def import_modules_from_strings(imports, allow_failed_imports=False):
+    """Import modules from the given list of strings.
+
+    Args:
+        imports (list | str | None): The given module names to be imported.
+        allow_failed_imports (bool): If True, the failed imports will return
+            None. Otherwise, an ImportError is raised. Default: False.
+
+    Returns:
+        list[module] | module | None: The imported modules.
+
+    Examples:
+        >>> osp, sys = import_modules_from_strings(
+        ...     ['os.path', 'sys'])
+        >>> import os.path as osp_
+        >>> import sys as sys_
+        >>> assert osp == osp_
+        >>> assert sys == sys_
+    """
+    if not imports:
+        return
+    single_import = False
+    if isinstance(imports, str):
+        single_import = True
+        imports = [imports]
+    if not isinstance(imports, list):
+        raise TypeError(
+            f'custom_imports must be a list but got type {type(imports)}')
+    imported = []
+    for imp in imports:
+        if not isinstance(imp, str):
+            raise TypeError(
+                f'{imp} is of type {type(imp)} and cannot be imported.')
+        try:
+            imported_tmp = import_module(imp)
+        except ImportError:
+            if allow_failed_imports:
+                warnings.warn(f'{imp} failed to import and is ignored.',
+                              UserWarning)
+                imported_tmp = None
+            else:
+                raise ImportError
+        imported.append(imported_tmp)
+    if single_import:
+        imported = imported[0]
+    return imported
+
+
+def iter_cast(inputs, dst_type, return_type=None):
+    """Cast elements of an iterable object into some type.
+
+    Args:
+        inputs (Iterable): The input object.
+        dst_type (type): Destination type.
+        return_type (type, optional): If specified, the output object will be
+            converted to this type, otherwise an iterator.
+
+    Returns:
+        iterator or specified type: The converted object.
+    """
+    if not isinstance(inputs, abc.Iterable):
+        raise TypeError('inputs must be an iterable object')
+    if not isinstance(dst_type, type):
+        raise TypeError('"dst_type" must be a valid type')
+
+    out_iterable = map(dst_type, inputs)
+
+    if return_type is None:
+        return out_iterable
+    else:
+        return return_type(out_iterable)
+
+
+def list_cast(inputs, dst_type):
+    """Cast elements of an iterable object into a list of some type.
+
+    A partial method of :func:`iter_cast`.
+    """
+    return iter_cast(inputs, dst_type, return_type=list)
+
+
+def tuple_cast(inputs, dst_type):
+    """Cast elements of an iterable object into a tuple of some type.
+
+    A partial method of :func:`iter_cast`.
+    """
+    return iter_cast(inputs, dst_type, return_type=tuple)
+
+
+def is_seq_of(seq, expected_type, seq_type=None):
+    """Check whether it is a sequence of some type.
+
+    Args:
+        seq (Sequence): The sequence to be checked.
+        expected_type (type): Expected type of sequence items.
+        seq_type (type, optional): Expected sequence type.
+
+    Returns:
+        bool: Whether the sequence is valid.
+    """
+    if seq_type is None:
+        exp_seq_type = abc.Sequence
+    else:
+        assert isinstance(seq_type, type)
+        exp_seq_type = seq_type
+    if not isinstance(seq, exp_seq_type):
+        return False
+    for item in seq:
+        if not isinstance(item, expected_type):
+            return False
+    return True
+
+
+def is_list_of(seq, expected_type):
+    """Check whether it is a list of some type.
+
+    A partial method of :func:`is_seq_of`.
+    """
+    return is_seq_of(seq, expected_type, seq_type=list)
+
+
+def is_tuple_of(seq, expected_type):
+    """Check whether it is a tuple of some type.
+
+    A partial method of :func:`is_seq_of`.
+    """
+    return is_seq_of(seq, expected_type, seq_type=tuple)
+
+
+def slice_list(in_list, lens):
+    """Slice a list into several sub lists by a list of given length.
+
+    Args:
+        in_list (list): The list to be sliced.
+        lens(int or list): The expected length of each out list.
+
+    Returns:
+        list: A list of sliced list.
+    """
+    if isinstance(lens, int):
+        assert len(in_list) % lens == 0
+        lens = [lens] * int(len(in_list) / lens)
+    if not isinstance(lens, list):
+        raise TypeError('"indices" must be an integer or a list of integers')
+    elif sum(lens) != len(in_list):
+        raise ValueError('sum of lens and list length does not '
+                         f'match: {sum(lens)} != {len(in_list)}')
+    out_list = []
+    idx = 0
+    for i in range(len(lens)):
+        out_list.append(in_list[idx:idx + lens[i]])
+        idx += lens[i]
+    return out_list
+
+
+def concat_list(in_list):
+    """Concatenate a list of list into a single list.
+
+    Args:
+        in_list (list): The list of list to be merged.
+
+    Returns:
+        list: The concatenated flat list.
+    """
+    return list(itertools.chain(*in_list))
+
+
+def check_prerequisites(
+        prerequisites,
+        checker,
+        msg_tmpl='Prerequisites "{}" are required in method "{}" but not '
+        'found, please install them first.'):  # yapf: disable
+    """A decorator factory to check if prerequisites are satisfied.
+
+    Args:
+        prerequisites (str or list[str]): Prerequisites to be checked.
+        checker (callable): The checker method that returns True if a
+            prerequisite is met, False otherwise.
+        msg_tmpl (str): The message template with two variables.
+
+    Returns:
+        decorator: A specific decorator.
+    """
+
+    def wrap(func):
+
+        @functools.wraps(func)
+        def wrapped_func(*args, **kwargs):
+            requirements = [prerequisites] if isinstance(
+                prerequisites, str) else prerequisites
+            missing = []
+            for item in requirements:
+                if not checker(item):
+                    missing.append(item)
+            if missing:
+                print(msg_tmpl.format(', '.join(missing), func.__name__))
+                raise RuntimeError('Prerequisites not met.')
+            else:
+                return func(*args, **kwargs)
+
+        return wrapped_func
+
+    return wrap
+
+
+def _check_py_package(package):
+    try:
+        import_module(package)
+    except ImportError:
+        return False
+    else:
+        return True
+
+
+def _check_executable(cmd):
+    if subprocess.call(f'which {cmd}', shell=True) != 0:
+        return False
+    else:
+        return True
+
+
+def requires_package(prerequisites):
+    """A decorator to check if some python packages are installed.
+
+    Example:
+        >>> @requires_package('numpy')
+        >>> def func(arg1, args):
+        >>>     return numpy.zeros(1)
+        array([0.])
+        >>> @requires_package(['numpy', 'non_package'])
+        >>> def func(arg1, args):
+        >>>     return numpy.zeros(1)
+        ImportError
+    """
+    return check_prerequisites(prerequisites, checker=_check_py_package)
+
+
+def requires_executable(prerequisites):
+    """A decorator to check if some executable files are installed.
+
+    Example:
+        >>> @requires_executable('ffmpeg')
+        >>> def func(arg1, args):
+        >>>     print(1)
+        1
+    """
+    return check_prerequisites(prerequisites, checker=_check_executable)
+
+
+def deprecated_api_warning(name_dict, cls_name=None):
+    """A decorator to check if some arguments are deprecate and try to replace
+    deprecate src_arg_name to dst_arg_name.
+
+    Args:
+        name_dict(dict):
+            key (str): Deprecate argument names.
+            val (str): Expected argument names.
+
+    Returns:
+        func: New function.
+    """
+
+    def api_warning_wrapper(old_func):
+
+        @functools.wraps(old_func)
+        def new_func(*args, **kwargs):
+            # get the arg spec of the decorated method
+            args_info = getfullargspec(old_func)
+            # get name of the function
+            func_name = old_func.__name__
+            if cls_name is not None:
+                func_name = f'{cls_name}.{func_name}'
+            if args:
+                arg_names = args_info.args[:len(args)]
+                for src_arg_name, dst_arg_name in name_dict.items():
+                    if src_arg_name in arg_names:
+                        warnings.warn(
+                            f'"{src_arg_name}" is deprecated in '
+                            f'`{func_name}`, please use "{dst_arg_name}" '
+                            'instead')
+                        arg_names[arg_names.index(src_arg_name)] = dst_arg_name
+            if kwargs:
+                for src_arg_name, dst_arg_name in name_dict.items():
+                    if src_arg_name in kwargs:
+
+                        assert dst_arg_name not in kwargs, (
+                            f'The expected behavior is to replace '
+                            f'the deprecated key `{src_arg_name}` with '
+                            f'the new key `{dst_arg_name}`, but got them '
+                            f'in the arguments at the same time, which '
+                            f'is confusing. `{src_arg_name}` will be '
+                            f'deprecated in the future, please '
+                            f'use `{dst_arg_name}` instead.')
+
+                        warnings.warn(
+                            f'"{src_arg_name}" is deprecated in '
+                            f'`{func_name}`, please use "{dst_arg_name}" '
+                            'instead')
+                        kwargs[dst_arg_name] = kwargs.pop(src_arg_name)
+
+            # apply converted arguments to the decorated method
+            output = old_func(*args, **kwargs)
+            return output
+
+        return new_func
+
+    return api_warning_wrapper
+
+
+def is_method_overridden(method, base_class, derived_class):
+    """Check if a method of base class is overridden in derived class.
+
+    Args:
+        method (str): the method name to check.
+        base_class (type): the class of the base class.
+        derived_class (type | Any): the class or instance of the derived class.
+    """
+    assert isinstance(base_class, type), \
+        "base_class doesn't accept instance, Please pass class instead."
+
+    if not isinstance(derived_class, type):
+        derived_class = derived_class.__class__
+
+    base_method = getattr(base_class, method)
+    derived_method = getattr(derived_class, method)
+    return derived_method != base_method
+
+
+def has_method(obj: object, method: str) -> bool:
+    """Check whether the object has a method.
+
+    Args:
+        obj (object): The object to check.
+        method (str): The method name to check.
+
+    Returns:
+        bool: True if the object has the method else False.
+    """
+    return hasattr(obj, method) and callable(getattr(obj, method))
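A few of the helpers above in action, as a quick sanity-check sketch (import path assumed from this patch's vendored layout):

from annotator.uniformer.mmcv.utils.misc import (concat_list, is_seq_of,
                                                 list_cast, slice_list)

assert list_cast(['1', '2', '3'], int) == [1, 2, 3]
assert is_seq_of([1, 2, 3], int, seq_type=list)
assert slice_list([1, 2, 3, 4, 5, 6], [2, 4]) == [[1, 2], [3, 4, 5, 6]]
assert concat_list([[1, 2], [3], [4, 5]]) == [1, 2, 3, 4, 5]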
diff --git a/annotator/uniformer/mmcv/utils/parrots_jit.py b/annotator/uniformer/mmcv/utils/parrots_jit.py
new file mode 100644
index 0000000000000000000000000000000000000000..61873f6dbb9b10ed972c90aa8faa321e3cb3249e
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/parrots_jit.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+
+from .parrots_wrapper import TORCH_VERSION
+
+parrots_jit_option = os.getenv('PARROTS_JIT_OPTION')
+
+if TORCH_VERSION == 'parrots' and parrots_jit_option == 'ON':
+    from parrots.jit import pat as jit
+else:
+
+    def jit(func=None,
+            check_input=None,
+            full_shape=True,
+            derivate=False,
+            coderize=False,
+            optimize=False):
+
+        def wrapper(func):
+
+            def wrapper_inner(*args, **kargs):
+                return func(*args, **kargs)
+
+            return wrapper_inner
+
+        if func is None:
+            return wrapper
+        else:
+            return func
+
+
+if TORCH_VERSION == 'parrots':
+    from parrots.utils.tester import skip_no_elena
+else:
+
+    def skip_no_elena(func):
+
+        def wrapper(*args, **kargs):
+            return func(*args, **kargs)
+
+        return wrapper
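Under standard PyTorch both shims above are no-op wrappers, so decorated functions behave exactly like the originals; a quick sketch:

from annotator.uniformer.mmcv.utils.parrots_jit import jit


@jit(coderize=True)
def add(a, b):
    return a + b


assert add(1, 2) == 3  # identical behaviour: the decorator only passes through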
diff --git a/annotator/uniformer/mmcv/utils/parrots_wrapper.py b/annotator/uniformer/mmcv/utils/parrots_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..93c97640d4b9ed088ca82cfe03e6efebfcfa9dbf
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/parrots_wrapper.py
@@ -0,0 +1,107 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from functools import partial
+
+import torch
+
+TORCH_VERSION = torch.__version__
+
+
+def is_rocm_pytorch() -> bool:
+    is_rocm = False
+    if TORCH_VERSION != 'parrots':
+        try:
+            from torch.utils.cpp_extension import ROCM_HOME
+            is_rocm = True if ((torch.version.hip is not None) and
+                               (ROCM_HOME is not None)) else False
+        except ImportError:
+            pass
+    return is_rocm
+
+
+def _get_cuda_home():
+    if TORCH_VERSION == 'parrots':
+        from parrots.utils.build_extension import CUDA_HOME
+    else:
+        if is_rocm_pytorch():
+            from torch.utils.cpp_extension import ROCM_HOME
+            CUDA_HOME = ROCM_HOME
+        else:
+            from torch.utils.cpp_extension import CUDA_HOME
+    return CUDA_HOME
+
+
+def get_build_config():
+    if TORCH_VERSION == 'parrots':
+        from parrots.config import get_build_info
+        return get_build_info()
+    else:
+        return torch.__config__.show()
+
+
+def _get_conv():
+    if TORCH_VERSION == 'parrots':
+        from parrots.nn.modules.conv import _ConvNd, _ConvTransposeMixin
+    else:
+        from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin
+    return _ConvNd, _ConvTransposeMixin
+
+
+def _get_dataloader():
+    if TORCH_VERSION == 'parrots':
+        from torch.utils.data import DataLoader, PoolDataLoader
+    else:
+        from torch.utils.data import DataLoader
+        PoolDataLoader = DataLoader
+    return DataLoader, PoolDataLoader
+
+
+def _get_extension():
+    if TORCH_VERSION == 'parrots':
+        from parrots.utils.build_extension import BuildExtension, Extension
+        CppExtension = partial(Extension, cuda=False)
+        CUDAExtension = partial(Extension, cuda=True)
+    else:
+        from torch.utils.cpp_extension import (BuildExtension, CppExtension,
+                                               CUDAExtension)
+    return BuildExtension, CppExtension, CUDAExtension
+
+
+def _get_pool():
+    if TORCH_VERSION == 'parrots':
+        from parrots.nn.modules.pool import (_AdaptiveAvgPoolNd,
+                                             _AdaptiveMaxPoolNd, _AvgPoolNd,
+                                             _MaxPoolNd)
+    else:
+        from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd,
+                                              _AdaptiveMaxPoolNd, _AvgPoolNd,
+                                              _MaxPoolNd)
+    return _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd
+
+
+def _get_norm():
+    if TORCH_VERSION == 'parrots':
+        from parrots.nn.modules.batchnorm import _BatchNorm, _InstanceNorm
+        SyncBatchNorm_ = torch.nn.SyncBatchNorm2d
+    else:
+        from torch.nn.modules.instancenorm import _InstanceNorm
+        from torch.nn.modules.batchnorm import _BatchNorm
+        SyncBatchNorm_ = torch.nn.SyncBatchNorm
+    return _BatchNorm, _InstanceNorm, SyncBatchNorm_
+
+
+_ConvNd, _ConvTransposeMixin = _get_conv()
+DataLoader, PoolDataLoader = _get_dataloader()
+BuildExtension, CppExtension, CUDAExtension = _get_extension()
+_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
+_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
+
+
+class SyncBatchNorm(SyncBatchNorm_):
+
+    def _check_input_dim(self, input):
+        if TORCH_VERSION == 'parrots':
+            if input.dim() < 2:
+                raise ValueError(
+                    f'expected at least 2D input (got {input.dim()}D input)')
+        else:
+            super()._check_input_dim(input)
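The aliases exported above allow backend-agnostic isinstance checks; a small sketch, assuming PyTorch is the active backend:

import torch.nn as nn

from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm, _InstanceNorm

bn = nn.BatchNorm2d(8)
assert isinstance(bn, _BatchNorm)          # matches any BatchNorm variant
assert not isinstance(bn, _InstanceNorm)   # but not instance norms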
diff --git a/annotator/uniformer/mmcv/utils/path.py b/annotator/uniformer/mmcv/utils/path.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dab4b3041413b1432b0f434b8b14783097d33c6
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/path.py
@@ -0,0 +1,101 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import os.path as osp
+from pathlib import Path
+
+from .misc import is_str
+
+
+def is_filepath(x):
+    return is_str(x) or isinstance(x, Path)
+
+
+def fopen(filepath, *args, **kwargs):
+    if is_str(filepath):
+        return open(filepath, *args, **kwargs)
+    elif isinstance(filepath, Path):
+        return filepath.open(*args, **kwargs)
+    raise ValueError('`filepath` should be a string or a Path')
+
+
+def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
+    if not osp.isfile(filename):
+        raise FileNotFoundError(msg_tmpl.format(filename))
+
+
+def mkdir_or_exist(dir_name, mode=0o777):
+    if dir_name == '':
+        return
+    dir_name = osp.expanduser(dir_name)
+    os.makedirs(dir_name, mode=mode, exist_ok=True)
+
+
+def symlink(src, dst, overwrite=True, **kwargs):
+    if os.path.lexists(dst) and overwrite:
+        os.remove(dst)
+    os.symlink(src, dst, **kwargs)
+
+
+def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
+    """Scan a directory to find the interested files.
+
+    Args:
+        dir_path (str | obj:`Path`): Path of the directory.
+        suffix (str | tuple(str), optional): File suffix that we are
+            interested in. Default: None.
+        recursive (bool, optional): If set to True, recursively scan the
+            directory. Default: False.
+        case_sensitive (bool, optional): If set to False, ignore the case of
+            the suffix. Default: True.
+
+    Returns:
+        A generator over the relative paths of all files of interest.
+    """
+    if isinstance(dir_path, (str, Path)):
+        dir_path = str(dir_path)
+    else:
+        raise TypeError('"dir_path" must be a string or Path object')
+
+    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+        raise TypeError('"suffix" must be a string or tuple of strings')
+
+    if suffix is not None and not case_sensitive:
+        suffix = suffix.lower() if isinstance(suffix, str) else tuple(
+            item.lower() for item in suffix)
+
+    root = dir_path
+
+    def _scandir(dir_path, suffix, recursive, case_sensitive):
+        for entry in os.scandir(dir_path):
+            if not entry.name.startswith('.') and entry.is_file():
+                rel_path = osp.relpath(entry.path, root)
+                _rel_path = rel_path if case_sensitive else rel_path.lower()
+                if suffix is None or _rel_path.endswith(suffix):
+                    yield rel_path
+            elif recursive and os.path.isdir(entry.path):
+                # scan recursively if entry.path is a directory
+                yield from _scandir(entry.path, suffix, recursive,
+                                    case_sensitive)
+
+    return _scandir(dir_path, suffix, recursive, case_sensitive)
+
+
+def find_vcs_root(path, markers=('.git', )):
+    """Finds the root directory (including itself) of specified markers.
+
+    Args:
+        path (str): Path of directory or file.
+        markers (list[str], optional): List of file or directory names.
+
+    Returns:
+        The directory containing one of the markers, or None if not found.
+    """
+    if osp.isfile(path):
+        path = osp.dirname(path)
+
+    prev, cur = None, osp.abspath(osp.expanduser(path))
+    while cur != prev:
+        if any(osp.exists(osp.join(cur, marker)) for marker in markers):
+            return cur
+        prev, cur = cur, osp.split(cur)[0]
+    return None
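A sketch of typical calls into these path helpers (the directory names are illustrative and assumed to exist):

from annotator.uniformer.mmcv.utils.path import (find_vcs_root, mkdir_or_exist,
                                                 scandir)

mkdir_or_exist('./work_dirs/demo')            # no error if it already exists

# List every .py file under an existing directory, as paths relative to it.
for rel_path in scandir('annotator', suffix='.py', recursive=True):
    print(rel_path)

repo_root = find_vcs_root('.')                # directory containing .git, or None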
diff --git a/annotator/uniformer/mmcv/utils/progressbar.py b/annotator/uniformer/mmcv/utils/progressbar.py
new file mode 100644
index 0000000000000000000000000000000000000000..0062f670dd94fa9da559ab26ef85517dcf5211c7
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/progressbar.py
@@ -0,0 +1,208 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import sys
+from collections.abc import Iterable
+from multiprocessing import Pool
+from shutil import get_terminal_size
+
+from .timer import Timer
+
+
+class ProgressBar:
+    """A progress bar which can print the progress."""
+
+    def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
+        self.task_num = task_num
+        self.bar_width = bar_width
+        self.completed = 0
+        self.file = file
+        if start:
+            self.start()
+
+    @property
+    def terminal_width(self):
+        width, _ = get_terminal_size()
+        return width
+
+    def start(self):
+        if self.task_num > 0:
+            self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, '
+                            'elapsed: 0s, ETA:')
+        else:
+            self.file.write('completed: 0, elapsed: 0s')
+        self.file.flush()
+        self.timer = Timer()
+
+    def update(self, num_tasks=1):
+        assert num_tasks > 0
+        self.completed += num_tasks
+        elapsed = self.timer.since_start()
+        if elapsed > 0:
+            fps = self.completed / elapsed
+        else:
+            fps = float('inf')
+        if self.task_num > 0:
+            percentage = self.completed / float(self.task_num)
+            eta = int(elapsed * (1 - percentage) / percentage + 0.5)
+            msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \
+                  f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \
+                  f'ETA: {eta:5}s'
+
+            bar_width = min(self.bar_width,
+                            int(self.terminal_width - len(msg)) + 2,
+                            int(self.terminal_width * 0.6))
+            bar_width = max(2, bar_width)
+            mark_width = int(bar_width * percentage)
+            bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width)
+            self.file.write(msg.format(bar_chars))
+        else:
+            self.file.write(
+                f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,'
+                f' {fps:.1f} tasks/s')
+        self.file.flush()
+
+
+def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs):
+    """Track the progress of tasks execution with a progress bar.
+
+    Tasks are done with a simple for-loop.
+
+    Args:
+        func (callable): The function to be applied to each task.
+        tasks (list or tuple[Iterable, int]): A list of tasks or
+            (tasks, total num).
+        bar_width (int): Width of progress bar.
+
+    Returns:
+        list: The task results.
+    """
+    if isinstance(tasks, tuple):
+        assert len(tasks) == 2
+        assert isinstance(tasks[0], Iterable)
+        assert isinstance(tasks[1], int)
+        task_num = tasks[1]
+        tasks = tasks[0]
+    elif isinstance(tasks, Iterable):
+        task_num = len(tasks)
+    else:
+        raise TypeError(
+            '"tasks" must be an iterable object or a (iterator, int) tuple')
+    prog_bar = ProgressBar(task_num, bar_width, file=file)
+    results = []
+    for task in tasks:
+        results.append(func(task, **kwargs))
+        prog_bar.update()
+    prog_bar.file.write('\n')
+    return results
+
+
+def init_pool(process_num, initializer=None, initargs=None):
+    if initializer is None:
+        return Pool(process_num)
+    elif initargs is None:
+        return Pool(process_num, initializer)
+    else:
+        if not isinstance(initargs, tuple):
+            raise TypeError('"initargs" must be a tuple')
+        return Pool(process_num, initializer, initargs)
+
+
+def track_parallel_progress(func,
+                            tasks,
+                            nproc,
+                            initializer=None,
+                            initargs=None,
+                            bar_width=50,
+                            chunksize=1,
+                            skip_first=False,
+                            keep_order=True,
+                            file=sys.stdout):
+    """Track the progress of parallel task execution with a progress bar.
+
+    The built-in :mod:`multiprocessing` module is used for process pools and
+    tasks are done with :func:`Pool.imap` or :func:`Pool.imap_unordered`.
+
+    Args:
+        func (callable): The function to be applied to each task.
+        tasks (list or tuple[Iterable, int]): A list of tasks or
+            (tasks, total num).
+        nproc (int): Process (worker) number.
+        initializer (None or callable): Refer to :class:`multiprocessing.Pool`
+            for details.
+        initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for
+            details.
+        chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
+        bar_width (int): Width of progress bar.
+        skip_first (bool): Whether to skip the first sample for each worker
+            when estimating fps, since the initialization step may take
+            longer.
+        keep_order (bool): If True, :func:`Pool.imap` is used, otherwise
+            :func:`Pool.imap_unordered` is used.
+
+    Returns:
+        list: The task results.
+    """
+    if isinstance(tasks, tuple):
+        assert len(tasks) == 2
+        assert isinstance(tasks[0], Iterable)
+        assert isinstance(tasks[1], int)
+        task_num = tasks[1]
+        tasks = tasks[0]
+    elif isinstance(tasks, Iterable):
+        task_num = len(tasks)
+    else:
+        raise TypeError(
+            '"tasks" must be an iterable object or a (iterator, int) tuple')
+    pool = init_pool(nproc, initializer, initargs)
+    start = not skip_first
+    task_num -= nproc * chunksize * int(skip_first)
+    prog_bar = ProgressBar(task_num, bar_width, start, file=file)
+    results = []
+    if keep_order:
+        gen = pool.imap(func, tasks, chunksize)
+    else:
+        gen = pool.imap_unordered(func, tasks, chunksize)
+    for result in gen:
+        results.append(result)
+        if skip_first:
+            if len(results) < nproc * chunksize:
+                continue
+            elif len(results) == nproc * chunksize:
+                prog_bar.start()
+                continue
+        prog_bar.update()
+    prog_bar.file.write('\n')
+    pool.close()
+    pool.join()
+    return results
+
+
+def track_iter_progress(tasks, bar_width=50, file=sys.stdout):
+    """Track the progress of tasks iteration or enumeration with a progress
+    bar.
+
+    Tasks are yielded with a simple for-loop.
+
+    Args:
+        tasks (list or tuple[Iterable, int]): A list of tasks or
+            (tasks, total num).
+        bar_width (int): Width of progress bar.
+
+    Yields:
+        The tasks from the given iterable, one at a time.
+    """
+    if isinstance(tasks, tuple):
+        assert len(tasks) == 2
+        assert isinstance(tasks[0], Iterable)
+        assert isinstance(tasks[1], int)
+        task_num = tasks[1]
+        tasks = tasks[0]
+    elif isinstance(tasks, Iterable):
+        task_num = len(tasks)
+    else:
+        raise TypeError(
+            '"tasks" must be an iterable object or a (iterator, int) tuple')
+    prog_bar = ProgressBar(task_num, bar_width, file=file)
+    for task in tasks:
+        yield task
+        prog_bar.update()
+    prog_bar.file.write('\n')
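A minimal sketch of the two most common entry points above:

import time

from annotator.uniformer.mmcv.utils.progressbar import (track_iter_progress,
                                                        track_progress)


def square(x):
    time.sleep(0.01)
    return x * x


results = track_progress(square, list(range(20)))   # bar over a simple for-loop
for x in track_iter_progress(list(range(20))):      # bar while iterating yourself
    square(x)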
diff --git a/annotator/uniformer/mmcv/utils/registry.py b/annotator/uniformer/mmcv/utils/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa9df39bc9f3d8d568361e7250ab35468f2b74e0
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/registry.py
@@ -0,0 +1,315 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import inspect
+import warnings
+from functools import partial
+
+from .misc import is_seq_of
+
+
+def build_from_cfg(cfg, registry, default_args=None):
+    """Build a module from config dict.
+
+    Args:
+        cfg (dict): Config dict. It should at least contain the key "type".
+        registry (:obj:`Registry`): The registry to search the type from.
+        default_args (dict, optional): Default initialization arguments.
+
+    Returns:
+        object: The constructed object.
+    """
+    if not isinstance(cfg, dict):
+        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
+    if 'type' not in cfg:
+        if default_args is None or 'type' not in default_args:
+            raise KeyError(
+                '`cfg` or `default_args` must contain the key "type", '
+                f'but got {cfg}\n{default_args}')
+    if not isinstance(registry, Registry):
+        raise TypeError('registry must be an mmcv.Registry object, '
+                        f'but got {type(registry)}')
+    if not (isinstance(default_args, dict) or default_args is None):
+        raise TypeError('default_args must be a dict or None, '
+                        f'but got {type(default_args)}')
+
+    args = cfg.copy()
+
+    if default_args is not None:
+        for name, value in default_args.items():
+            args.setdefault(name, value)
+
+    obj_type = args.pop('type')
+    if isinstance(obj_type, str):
+        obj_cls = registry.get(obj_type)
+        if obj_cls is None:
+            raise KeyError(
+                f'{obj_type} is not in the {registry.name} registry')
+    elif inspect.isclass(obj_type):
+        obj_cls = obj_type
+    else:
+        raise TypeError(
+            f'type must be a str or valid type, but got {type(obj_type)}')
+    try:
+        return obj_cls(**args)
+    except Exception as e:
+        # Normal TypeError does not print class name.
+        raise type(e)(f'{obj_cls.__name__}: {e}')
+
+
+class Registry:
+    """A registry to map strings to classes.
+
+    Registered objects can be built from the registry.
+    Example:
+        >>> MODELS = Registry('models')
+        >>> @MODELS.register_module()
+        >>> class ResNet:
+        >>>     pass
+        >>> resnet = MODELS.build(dict(type='ResNet'))
+
+    Please refer to
+    https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
+    advanced usage.
+
+    Args:
+        name (str): Registry name.
+        build_func (func, optional): Build function used to construct instances
+            from the registry. :func:`build_from_cfg` is used if neither
+            ``parent`` nor ``build_func`` is specified. If ``parent`` is
+            specified and ``build_func`` is not given, ``build_func`` will be
+            inherited from ``parent``. Default: None.
+        parent (Registry, optional): Parent registry. Classes registered in a
+            child registry can be built from the parent. Default: None.
+        scope (str, optional): The scope of the registry. It is the key used to
+            search for child registries. If not specified, the scope will be
+            the name of the package where the class is defined, e.g. mmdet,
+            mmcls, mmseg. Default: None.
+    """
+
+    def __init__(self, name, build_func=None, parent=None, scope=None):
+        self._name = name
+        self._module_dict = dict()
+        self._children = dict()
+        self._scope = self.infer_scope() if scope is None else scope
+
+        # self.build_func will be set with the following priority:
+        # 1. build_func
+        # 2. parent.build_func
+        # 3. build_from_cfg
+        if build_func is None:
+            if parent is not None:
+                self.build_func = parent.build_func
+            else:
+                self.build_func = build_from_cfg
+        else:
+            self.build_func = build_func
+        if parent is not None:
+            assert isinstance(parent, Registry)
+            parent._add_children(self)
+            self.parent = parent
+        else:
+            self.parent = None
+
+    def __len__(self):
+        return len(self._module_dict)
+
+    def __contains__(self, key):
+        return self.get(key) is not None
+
+    def __repr__(self):
+        format_str = self.__class__.__name__ + \
+                     f'(name={self._name}, ' \
+                     f'items={self._module_dict})'
+        return format_str
+
+    @staticmethod
+    def infer_scope():
+        """Infer the scope of registry.
+
+        The name of the package where registry is defined will be returned.
+
+        Example:
+            # in mmdet/models/backbone/resnet.py
+            >>> MODELS = Registry('models')
+            >>> @MODELS.register_module()
+            >>> class ResNet:
+            >>>     pass
+            The scope of ``ResNet`` will be ``mmdet``.
+
+
+        Returns:
+            scope (str): The inferred scope name.
+        """
+        # inspect.stack() traces the call chain; index 2 is the frame in which
+        # the Registry is instantiated (two levels above `infer_scope()`)
+        filename = inspect.getmodule(inspect.stack()[2][0]).__name__
+        split_filename = filename.split('.')
+        return split_filename[0]
+
+    @staticmethod
+    def split_scope_key(key):
+        """Split scope and key.
+
+        The first scope will be split from key.
+
+        Examples:
+            >>> Registry.split_scope_key('mmdet.ResNet')
+            'mmdet', 'ResNet'
+            >>> Registry.split_scope_key('ResNet')
+            None, 'ResNet'
+
+        Return:
+            scope (str, None): The first scope.
+            key (str): The remaining key.
+        """
+        split_index = key.find('.')
+        if split_index != -1:
+            return key[:split_index], key[split_index + 1:]
+        else:
+            return None, key
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def scope(self):
+        return self._scope
+
+    @property
+    def module_dict(self):
+        return self._module_dict
+
+    @property
+    def children(self):
+        return self._children
+
+    def get(self, key):
+        """Get the registry record.
+
+        Args:
+            key (str): The class name in string format.
+
+        Returns:
+            class: The corresponding class.
+        """
+        scope, real_key = self.split_scope_key(key)
+        if scope is None or scope == self._scope:
+            # get from self
+            if real_key in self._module_dict:
+                return self._module_dict[real_key]
+        else:
+            # get from self._children
+            if scope in self._children:
+                return self._children[scope].get(real_key)
+            else:
+                # goto root
+                parent = self.parent
+                while parent.parent is not None:
+                    parent = parent.parent
+                return parent.get(key)
+
+    def build(self, *args, **kwargs):
+        return self.build_func(*args, **kwargs, registry=self)
+
+    def _add_children(self, registry):
+        """Add children for a registry.
+
+        The ``registry`` will be added as a child based on its scope.
+        The parent registry can build objects from its child registries.
+
+        Example:
+            >>> models = Registry('models')
+            >>> mmdet_models = Registry('models', parent=models)
+            >>> @mmdet_models.register_module()
+            >>> class ResNet:
+            >>>     pass
+            >>> resnet = models.build(dict(type='mmdet.ResNet'))
+        """
+
+        assert isinstance(registry, Registry)
+        assert registry.scope is not None
+        assert registry.scope not in self.children, \
+            f'scope {registry.scope} exists in {self.name} registry'
+        self.children[registry.scope] = registry
+
+    def _register_module(self, module_class, module_name=None, force=False):
+        if not inspect.isclass(module_class):
+            raise TypeError('module must be a class, '
+                            f'but got {type(module_class)}')
+
+        if module_name is None:
+            module_name = module_class.__name__
+        if isinstance(module_name, str):
+            module_name = [module_name]
+        for name in module_name:
+            if not force and name in self._module_dict:
+                raise KeyError(f'{name} is already registered '
+                               f'in {self.name}')
+            self._module_dict[name] = module_class
+
+    def deprecated_register_module(self, cls=None, force=False):
+        warnings.warn(
+            'The old API of register_module(module, force=False) '
+            'is deprecated and will be removed, please use the new API '
+            'register_module(name=None, force=False, module=None) instead.')
+        if cls is None:
+            return partial(self.deprecated_register_module, force=force)
+        self._register_module(cls, force=force)
+        return cls
+
+    def register_module(self, name=None, force=False, module=None):
+        """Register a module.
+
+        A record will be added to `self._module_dict`, whose key is the class
+        name or the specified name, and value is the class itself.
+        It can be used as a decorator or a normal function.
+
+        Example:
+            >>> backbones = Registry('backbone')
+            >>> @backbones.register_module()
+            >>> class ResNet:
+            >>>     pass
+
+            >>> backbones = Registry('backbone')
+            >>> @backbones.register_module(name='mnet')
+            >>> class MobileNet:
+            >>>     pass
+
+            >>> backbones = Registry('backbone')
+            >>> class ResNet:
+            >>>     pass
+            >>> backbones.register_module(ResNet)
+
+        Args:
+            name (str | None): The module name to be registered. If not
+                specified, the class name will be used.
+            force (bool, optional): Whether to override an existing class with
+                the same name. Default: False.
+            module (type): Module class to be registered.
+        """
+        if not isinstance(force, bool):
+            raise TypeError(f'force must be a boolean, but got {type(force)}')
+        # NOTE: This is a workaround to stay compatible with the old API;
+        # it may introduce unexpected bugs.
+        if isinstance(name, type):
+            return self.deprecated_register_module(name, force=force)
+
+        # raise the error ahead of time
+        if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
+            raise TypeError(
+                'name must be None, an instance of str, or a sequence '
+                f'of str, but got {type(name)}')
+
+        # use it as a normal method: x.register_module(module=SomeClass)
+        if module is not None:
+            self._register_module(
+                module_class=module, module_name=name, force=force)
+            return module
+
+        # use it as a decorator: @x.register_module()
+        def _register(cls):
+            self._register_module(
+                module_class=cls, module_name=name, force=force)
+            return cls
+
+        return _register
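
A minimal usage sketch for the registry above (it assumes the default ``build_from_cfg`` builder, which pops ``type`` from the config dict and instantiates the registered class with the remaining keys):

>>> backbones = Registry('backbone')
>>> @backbones.register_module()
>>> class ResNet:
>>>     def __init__(self, depth):
>>>         self.depth = depth
>>> model = backbones.build(dict(type='ResNet', depth=50))
>>> model.depth
50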
diff --git a/annotator/uniformer/mmcv/utils/testing.py b/annotator/uniformer/mmcv/utils/testing.py
new file mode 100644
index 0000000000000000000000000000000000000000..a27f936da8ec14bac18562ede0a79d476d82f797
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/testing.py
@@ -0,0 +1,140 @@
+# Copyright (c) Open-MMLab.
+import sys
+from collections.abc import Iterable
+from runpy import run_path
+from shlex import split
+from typing import Any, Dict, List
+from unittest.mock import patch
+
+
+def check_python_script(cmd):
+    """Run the python cmd script with `__main__`. The difference between
+    `os.system` is that, this function exectues code in the current process, so
+    that it can be tracked by coverage tools. Currently it supports two forms:
+
+    - ./tests/data/scripts/hello.py zz
+    - python tests/data/scripts/hello.py zz
+    """
+    args = split(cmd)
+    if args[0] == 'python':
+        args = args[1:]
+    with patch.object(sys, 'argv', args):
+        run_path(args[0], run_name='__main__')
+
+
+def _any(judge_result):
+    """Since built-in ``any`` works only when the element of iterable is not
+    iterable, implement the function."""
+    if not isinstance(judge_result, Iterable):
+        return judge_result
+
+    try:
+        for element in judge_result:
+            if _any(element):
+                return True
+    except TypeError:
+        # Maybe encounter the case: torch.tensor(True) | torch.tensor(False)
+        if judge_result:
+            return True
+    return False
+
+
+def assert_dict_contains_subset(dict_obj: Dict[Any, Any],
+                                expected_subset: Dict[Any, Any]) -> bool:
+    """Check if the dict_obj contains the expected_subset.
+
+    Args:
+        dict_obj (Dict[Any, Any]): Dict object to be checked.
+        expected_subset (Dict[Any, Any]): Subset expected to be contained in
+            dict_obj.
+
+    Returns:
+        bool: Whether the dict_obj contains the expected_subset.
+    """
+
+    for key, value in expected_subset.items():
+        if key not in dict_obj.keys() or _any(dict_obj[key] != value):
+            return False
+    return True
+
+
+def assert_attrs_equal(obj: Any, expected_attrs: Dict[str, Any]) -> bool:
+    """Check if attribute of class object is correct.
+
+    Args:
+        obj (object): Class object to be checked.
+        expected_attrs (Dict[str, Any]): Dict of the expected attrs.
+
+    Returns:
+        bool: Whether the attribute of class object is correct.
+    """
+    for attr, value in expected_attrs.items():
+        if not hasattr(obj, attr) or _any(getattr(obj, attr) != value):
+            return False
+    return True
+
+
+def assert_dict_has_keys(obj: Dict[str, Any],
+                         expected_keys: List[str]) -> bool:
+    """Check if the obj has all the expected_keys.
+
+    Args:
+        obj (Dict[str, Any]): Object to be checked.
+        expected_keys (List[str]): Keys expected to contained in the keys of
+            the obj.
+
+    Returns:
+        bool: Whether the obj has the expected keys.
+    """
+    return set(expected_keys).issubset(set(obj.keys()))
+
+
+def assert_keys_equal(result_keys: List[str], target_keys: List[str]) -> bool:
+    """Check if target_keys is equal to result_keys.
+
+    Args:
+        result_keys (List[str]): Result keys to be checked.
+        target_keys (List[str]): Target keys to be checked.
+
+    Returns:
+        bool: Whether target_keys is equal to result_keys.
+    """
+    return set(result_keys) == set(target_keys)
+
+
+def assert_is_norm_layer(module) -> bool:
+    """Check if the module is a norm layer.
+
+    Args:
+        module (nn.Module): The module to be checked.
+
+    Returns:
+        bool: Whether the module is a norm layer.
+    """
+    from .parrots_wrapper import _BatchNorm, _InstanceNorm
+    from torch.nn import GroupNorm, LayerNorm
+    norm_layer_candidates = (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)
+    return isinstance(module, norm_layer_candidates)
+
+
+def assert_params_all_zeros(module) -> bool:
+    """Check if the parameters of the module is all zeros.
+
+    Args:
+        module (nn.Module): The module to be checked.
+
+    Returns:
+        bool: Whether the parameters of the module is all zeros.
+    """
+    weight_data = module.weight.data
+    is_weight_zero = weight_data.allclose(
+        weight_data.new_zeros(weight_data.size()))
+
+    if hasattr(module, 'bias') and module.bias is not None:
+        bias_data = module.bias.data
+        is_bias_zero = bias_data.allclose(
+            bias_data.new_zeros(bias_data.size()))
+    else:
+        is_bias_zero = True
+
+    return is_weight_zero and is_bias_zero
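
A few illustrative checks with the helpers above (a sketch only; the dicts and keys are arbitrary examples):

>>> assert_dict_contains_subset({'a': 1, 'b': 2}, {'a': 1})
True
>>> assert_dict_has_keys({'a': 1, 'b': 2}, ['a', 'b'])
True
>>> assert_keys_equal(['a', 'b'], ['b', 'a'])
True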
diff --git a/annotator/uniformer/mmcv/utils/timer.py b/annotator/uniformer/mmcv/utils/timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3db7d497d8b374e18b5297e0a1d6eb186fd8cba
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/timer.py
@@ -0,0 +1,118 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from time import time
+
+
+class TimerError(Exception):
+
+    def __init__(self, message):
+        self.message = message
+        super(TimerError, self).__init__(message)
+
+
+class Timer:
+    """A flexible Timer class.
+
+    :Example:
+
+    >>> import time
+    >>> import annotator.uniformer.mmcv as mmcv
+    >>> with mmcv.Timer():
+    >>>     # simulate a code block that will run for 1s
+    >>>     time.sleep(1)
+    1.000
+    >>> with mmcv.Timer(print_tmpl='it takes {:.1f} seconds'):
+    >>>     # simulate a code block that will run for 1s
+    >>>     time.sleep(1)
+    it takes 1.0 seconds
+    >>> timer = mmcv.Timer()
+    >>> time.sleep(0.5)
+    >>> print(timer.since_start())
+    0.500
+    >>> time.sleep(0.5)
+    >>> print(timer.since_last_check())
+    0.500
+    >>> print(timer.since_start())
+    1.000
+    """
+
+    def __init__(self, start=True, print_tmpl=None):
+        self._is_running = False
+        self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}'
+        if start:
+            self.start()
+
+    @property
+    def is_running(self):
+        """bool: indicate whether the timer is running"""
+        return self._is_running
+
+    def __enter__(self):
+        self.start()
+        return self
+
+    def __exit__(self, type, value, traceback):
+        print(self.print_tmpl.format(self.since_last_check()))
+        self._is_running = False
+
+    def start(self):
+        """Start the timer."""
+        if not self._is_running:
+            self._t_start = time()
+            self._is_running = True
+        self._t_last = time()
+
+    def since_start(self):
+        """Total time since the timer is started.
+
+        Returns (float): Time in seconds.
+        """
+        if not self._is_running:
+            raise TimerError('timer is not running')
+        self._t_last = time()
+        return self._t_last - self._t_start
+
+    def since_last_check(self):
+        """Time since the last checking.
+
+        Either :func:`since_start` or :func:`since_last_check` is a checking
+        operation.
+
+        Returns (float): Time in seconds.
+        """
+        if not self._is_running:
+            raise TimerError('timer is not running')
+        dur = time() - self._t_last
+        self._t_last = time()
+        return dur
+
+
+_g_timers = {}  # global timers
+
+
+def check_time(timer_id):
+    """Add check points in a single line.
+
+    This method is suitable for running a task on a list of items. A timer will
+    be registered when the method is called for the first time.
+
+    :Example:
+
+    >>> import time
+    >>> import annotator.uniformer.mmcv as mmcv
+    >>> for i in range(1, 6):
+    >>>     # simulate a code block
+    >>>     time.sleep(i)
+    >>>     mmcv.check_time('task1')
+    2.000
+    3.000
+    4.000
+    5.000
+
+    Args:
+        timer_id (str): Timer identifier.
+    """
+    if timer_id not in _g_timers:
+        _g_timers[timer_id] = Timer()
+        return 0
+    else:
+        return _g_timers[timer_id].since_last_check()
diff --git a/annotator/uniformer/mmcv/utils/trace.py b/annotator/uniformer/mmcv/utils/trace.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ca99dc3eda05ef980d9a4249b50deca8273b6cc
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/trace.py
@@ -0,0 +1,23 @@
+import warnings
+
+import torch
+
+from annotator.uniformer.mmcv.utils import digit_version
+
+
+def is_jit_tracing() -> bool:
+    if (torch.__version__ != 'parrots'
+            and digit_version(torch.__version__) >= digit_version('1.6.0')):
+        on_trace = torch.jit.is_tracing()
+        # In PyTorch 1.6, torch.jit.is_tracing has a bug.
+        # Refers to https://github.com/pytorch/pytorch/issues/42448
+        if isinstance(on_trace, bool):
+            return on_trace
+        else:
+            return torch._C._is_tracing()
+    else:
+        warnings.warn(
+            'torch.jit.is_tracing is only supported after v1.6.0. '
+            'Therefore is_tracing returns False automatically. Please '
+            'set on_trace manually if you are using trace.', UserWarning)
+        return False
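
A typical use of ``is_jit_tracing`` is to skip data-dependent Python checks while a module is being traced; the toy module below is only an illustrative sketch:

>>> import torch
>>> from annotator.uniformer.mmcv.utils.trace import is_jit_tracing
>>> class ToyHead(torch.nn.Module):
>>>     def forward(self, x):
>>>         if not is_jit_tracing():
>>>             assert x.dim() == 2  # shape check that should not be traced
>>>         return x * 2
>>> traced = torch.jit.trace(ToyHead(), torch.randn(4, 8))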
diff --git a/annotator/uniformer/mmcv/utils/version_utils.py b/annotator/uniformer/mmcv/utils/version_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..963c45a2e8a86a88413ab6c18c22481fb9831985
--- /dev/null
+++ b/annotator/uniformer/mmcv/utils/version_utils.py
@@ -0,0 +1,90 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import subprocess
+import warnings
+
+from packaging.version import parse
+
+
+def digit_version(version_str: str, length: int = 4):
+    """Convert a version string into a tuple of integers.
+
+    This method is usually used for comparing two versions. For pre-release
+    versions: alpha < beta < rc.
+
+    Args:
+        version_str (str): The version string.
+        length (int): The maximum number of version levels. Default: 4.
+
+    Returns:
+        tuple[int]: The version info in digits (integers).
+    """
+    assert 'parrots' not in version_str
+    version = parse(version_str)
+    assert version.release, f'failed to parse version {version_str}'
+    release = list(version.release)
+    release = release[:length]
+    if len(release) < length:
+        release = release + [0] * (length - len(release))
+    if version.is_prerelease:
+        mapping = {'a': -3, 'b': -2, 'rc': -1}
+        val = -4
+        # version.pre can be None
+        if version.pre:
+            if version.pre[0] not in mapping:
+                warnings.warn(f'unknown prerelease version {version.pre[0]}, '
+                              'version checking may go wrong')
+            else:
+                val = mapping[version.pre[0]]
+            release.extend([val, version.pre[-1]])
+        else:
+            release.extend([val, 0])
+
+    elif version.is_postrelease:
+        release.extend([1, version.post])
+    else:
+        release.extend([0, 0])
+    return tuple(release)
+
+
+def _minimal_ext_cmd(cmd):
+    # construct minimal environment
+    env = {}
+    for k in ['SYSTEMROOT', 'PATH', 'HOME']:
+        v = os.environ.get(k)
+        if v is not None:
+            env[k] = v
+    # LANGUAGE is used on win32
+    env['LANGUAGE'] = 'C'
+    env['LANG'] = 'C'
+    env['LC_ALL'] = 'C'
+    out = subprocess.Popen(
+        cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
+    return out
+
+
+def get_git_hash(fallback='unknown', digits=None):
+    """Get the git hash of the current repo.
+
+    Args:
+        fallback (str, optional): The fallback string when git hash is
+            unavailable. Defaults to 'unknown'.
+        digits (int, optional): kept digits of the hash. Defaults to None,
+            meaning all digits are kept.
+
+    Returns:
+        str: Git commit hash.
+    """
+
+    if digits is not None and not isinstance(digits, int):
+        raise TypeError('digits must be None or an integer')
+
+    try:
+        out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
+        sha = out.strip().decode('ascii')
+        if digits is not None:
+            sha = sha[:digits]
+    except OSError:
+        sha = fallback
+
+    return sha
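
For illustration, ``digit_version`` pads the release to ``length`` digits and appends two extra slots so that pre-releases sort below the corresponding final release:

>>> digit_version('1.3.0')
(1, 3, 0, 0, 0, 0)
>>> digit_version('1.6.0rc1')
(1, 6, 0, 0, -1, 1)
>>> digit_version('1.6.0rc1') < digit_version('1.6.0')
True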
diff --git a/annotator/uniformer/mmcv/version.py b/annotator/uniformer/mmcv/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cce4e50bd692d4002e3cac3c545a3fb2efe95d0
--- /dev/null
+++ b/annotator/uniformer/mmcv/version.py
@@ -0,0 +1,35 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+__version__ = '1.3.17'
+
+
+def parse_version_info(version_str: str, length: int = 4) -> tuple:
+    """Parse a version string into a tuple.
+
+    Args:
+        version_str (str): The version string.
+        length (int): The maximum number of version levels. Default: 4.
+
+    Returns:
+        tuple[int | str]: The version info, e.g., "1.3.0" is parsed into
+            (1, 3, 0, 0, 0, 0), and "2.0.0rc1" is parsed into
+            (2, 0, 0, 0, 'rc', 1) (when length is set to 4).
+    """
+    from packaging.version import parse
+    version = parse(version_str)
+    assert version.release, f'failed to parse version {version_str}'
+    release = list(version.release)
+    release = release[:length]
+    if len(release) < length:
+        release = release + [0] * (length - len(release))
+    if version.is_prerelease:
+        release.extend(list(version.pre))
+    elif version.is_postrelease:
+        # ``version.post`` is an int, not an iterable, so wrap it directly
+        release.extend([0, version.post])
+    else:
+        release.extend([0, 0])
+    return tuple(release)
+
+
+version_info = tuple(int(x) for x in __version__.split('.')[:3])
+
+__all__ = ['__version__', 'version_info', 'parse_version_info']
diff --git a/annotator/uniformer/mmcv/video/__init__.py b/annotator/uniformer/mmcv/video/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73199b01dec52820dc6ca0139903536344d5a1eb
--- /dev/null
+++ b/annotator/uniformer/mmcv/video/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .io import Cache, VideoReader, frames2video
+from .optflow import (dequantize_flow, flow_from_bytes, flow_warp, flowread,
+                      flowwrite, quantize_flow, sparse_flow_from_bytes)
+from .processing import concat_video, convert_video, cut_video, resize_video
+
+__all__ = [
+    'Cache', 'VideoReader', 'frames2video', 'convert_video', 'resize_video',
+    'cut_video', 'concat_video', 'flowread', 'flowwrite', 'quantize_flow',
+    'dequantize_flow', 'flow_warp', 'flow_from_bytes', 'sparse_flow_from_bytes'
+]
diff --git a/annotator/uniformer/mmcv/video/io.py b/annotator/uniformer/mmcv/video/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..9879154227f640c262853b92c219461c6f67ee8e
--- /dev/null
+++ b/annotator/uniformer/mmcv/video/io.py
@@ -0,0 +1,318 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from collections import OrderedDict
+
+import cv2
+from cv2 import (CAP_PROP_FOURCC, CAP_PROP_FPS, CAP_PROP_FRAME_COUNT,
+                 CAP_PROP_FRAME_HEIGHT, CAP_PROP_FRAME_WIDTH,
+                 CAP_PROP_POS_FRAMES, VideoWriter_fourcc)
+
+from annotator.uniformer.mmcv.utils import (check_file_exist, mkdir_or_exist, scandir,
+                        track_progress)
+
+
+class Cache:
+
+    def __init__(self, capacity):
+        self._cache = OrderedDict()
+        self._capacity = int(capacity)
+        if capacity <= 0:
+            raise ValueError('capacity must be a positive integer')
+
+    @property
+    def capacity(self):
+        return self._capacity
+
+    @property
+    def size(self):
+        return len(self._cache)
+
+    def put(self, key, val):
+        if key in self._cache:
+            return
+        if len(self._cache) >= self.capacity:
+            self._cache.popitem(last=False)
+        self._cache[key] = val
+
+    def get(self, key, default=None):
+        val = self._cache[key] if key in self._cache else default
+        return val
+
+
+class VideoReader:
+    """Video class with similar usage to a list object.
+
+    This video wrapper class provides convenient apis to access frames.
+    OpenCV's VideoCapture class has an issue where jumping to a certain
+    frame may be inaccurate. This class fixes it by checking the position
+    after each jump.
+    Cache is used when decoding videos. So if the same frame is visited for
+    the second time, there is no need to decode again if it is stored in the
+    cache.
+
+    :Example:
+
+    >>> import annotator.uniformer.mmcv as mmcv
+    >>> v = mmcv.VideoReader('sample.mp4')
+    >>> len(v)  # get the total frame number with `len()`
+    120
+    >>> for img in v:  # v is iterable
+    >>>     mmcv.imshow(img)
+    >>> v[5]  # get the 6th frame
+    """
+
+    def __init__(self, filename, cache_capacity=10):
+        # Check whether the video path is a url
+        if not filename.startswith(('https://', 'http://')):
+            check_file_exist(filename, 'Video file not found: ' + filename)
+        self._vcap = cv2.VideoCapture(filename)
+        assert cache_capacity > 0
+        self._cache = Cache(cache_capacity)
+        self._position = 0
+        # get basic info
+        self._width = int(self._vcap.get(CAP_PROP_FRAME_WIDTH))
+        self._height = int(self._vcap.get(CAP_PROP_FRAME_HEIGHT))
+        self._fps = self._vcap.get(CAP_PROP_FPS)
+        self._frame_cnt = int(self._vcap.get(CAP_PROP_FRAME_COUNT))
+        self._fourcc = self._vcap.get(CAP_PROP_FOURCC)
+
+    @property
+    def vcap(self):
+        """:obj:`cv2.VideoCapture`: The raw VideoCapture object."""
+        return self._vcap
+
+    @property
+    def opened(self):
+        """bool: Indicate whether the video is opened."""
+        return self._vcap.isOpened()
+
+    @property
+    def width(self):
+        """int: Width of video frames."""
+        return self._width
+
+    @property
+    def height(self):
+        """int: Height of video frames."""
+        return self._height
+
+    @property
+    def resolution(self):
+        """tuple: Video resolution (width, height)."""
+        return (self._width, self._height)
+
+    @property
+    def fps(self):
+        """float: FPS of the video."""
+        return self._fps
+
+    @property
+    def frame_cnt(self):
+        """int: Total frames of the video."""
+        return self._frame_cnt
+
+    @property
+    def fourcc(self):
+        """str: "Four character code" of the video."""
+        return self._fourcc
+
+    @property
+    def position(self):
+        """int: Current cursor position, indicating frame decoded."""
+        return self._position
+
+    def _get_real_position(self):
+        return int(round(self._vcap.get(CAP_PROP_POS_FRAMES)))
+
+    def _set_real_position(self, frame_id):
+        self._vcap.set(CAP_PROP_POS_FRAMES, frame_id)
+        pos = self._get_real_position()
+        for _ in range(frame_id - pos):
+            self._vcap.read()
+        self._position = frame_id
+
+    def read(self):
+        """Read the next frame.
+
+        If the next frame has been decoded before and is in the cache, it is
+        returned directly; otherwise it is decoded, cached and returned.
+
+        Returns:
+            ndarray or None: Return the frame if successful, otherwise None.
+        """
+        # pos = self._position
+        if self._cache:
+            img = self._cache.get(self._position)
+            if img is not None:
+                ret = True
+            else:
+                if self._position != self._get_real_position():
+                    self._set_real_position(self._position)
+                ret, img = self._vcap.read()
+                if ret:
+                    self._cache.put(self._position, img)
+        else:
+            ret, img = self._vcap.read()
+        if ret:
+            self._position += 1
+        return img
+
+    def get_frame(self, frame_id):
+        """Get frame by index.
+
+        Args:
+            frame_id (int): Index of the expected frame, 0-based.
+
+        Returns:
+            ndarray or None: Return the frame if successful, otherwise None.
+        """
+        if frame_id < 0 or frame_id >= self._frame_cnt:
+            raise IndexError(
+                f'"frame_id" must be between 0 and {self._frame_cnt - 1}')
+        if frame_id == self._position:
+            return self.read()
+        if self._cache:
+            img = self._cache.get(frame_id)
+            if img is not None:
+                self._position = frame_id + 1
+                return img
+        self._set_real_position(frame_id)
+        ret, img = self._vcap.read()
+        if ret:
+            if self._cache:
+                self._cache.put(self._position, img)
+            self._position += 1
+        return img
+
+    def current_frame(self):
+        """Get the current frame (frame that is just visited).
+
+        Returns:
+            ndarray or None: If the video is fresh, return None, otherwise
+                return the frame.
+        """
+        if self._position == 0:
+            return None
+        return self._cache.get(self._position - 1)
+
+    def cvt2frames(self,
+                   frame_dir,
+                   file_start=0,
+                   filename_tmpl='{:06d}.jpg',
+                   start=0,
+                   max_num=0,
+                   show_progress=True):
+        """Convert a video to frame images.
+
+        Args:
+            frame_dir (str): Output directory to store all the frame images.
+            file_start (int): Filenames will start from the specified number.
+            filename_tmpl (str): Filename template with the index as the
+                placeholder.
+            start (int): The starting frame index.
+            max_num (int): Maximum number of frames to be written.
+            show_progress (bool): Whether to show a progress bar.
+        """
+        mkdir_or_exist(frame_dir)
+        if max_num == 0:
+            task_num = self.frame_cnt - start
+        else:
+            task_num = min(self.frame_cnt - start, max_num)
+        if task_num <= 0:
+            raise ValueError('start must be less than total frame number')
+        if start > 0:
+            self._set_real_position(start)
+
+        def write_frame(file_idx):
+            img = self.read()
+            if img is None:
+                return
+            filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
+            cv2.imwrite(filename, img)
+
+        if show_progress:
+            track_progress(write_frame, range(file_start,
+                                              file_start + task_num))
+        else:
+            for i in range(task_num):
+                write_frame(file_start + i)
+
+    def __len__(self):
+        return self.frame_cnt
+
+    def __getitem__(self, index):
+        if isinstance(index, slice):
+            return [
+                self.get_frame(i)
+                for i in range(*index.indices(self.frame_cnt))
+            ]
+        # support negative indexing
+        if index < 0:
+            index += self.frame_cnt
+            if index < 0:
+                raise IndexError('index out of range')
+        return self.get_frame(index)
+
+    def __iter__(self):
+        self._set_real_position(0)
+        return self
+
+    def __next__(self):
+        img = self.read()
+        if img is not None:
+            return img
+        else:
+            raise StopIteration
+
+    next = __next__
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self._vcap.release()
+
+
+def frames2video(frame_dir,
+                 video_file,
+                 fps=30,
+                 fourcc='XVID',
+                 filename_tmpl='{:06d}.jpg',
+                 start=0,
+                 end=0,
+                 show_progress=True):
+    """Read the frame images from a directory and join them as a video.
+
+    Args:
+        frame_dir (str): The directory containing video frames.
+        video_file (str): Output filename.
+        fps (float): FPS of the output video.
+        fourcc (str): Fourcc of the output video, this should be compatible
+            with the output file type.
+        filename_tmpl (str): Filename template with the index as the variable.
+        start (int): Starting frame index.
+        end (int): Ending frame index.
+        show_progress (bool): Whether to show a progress bar.
+    """
+    if end == 0:
+        ext = filename_tmpl.split('.')[-1]
+        end = len([name for name in scandir(frame_dir, ext)])
+    first_file = osp.join(frame_dir, filename_tmpl.format(start))
+    check_file_exist(first_file, 'The start frame not found: ' + first_file)
+    img = cv2.imread(first_file)
+    height, width = img.shape[:2]
+    resolution = (width, height)
+    vwriter = cv2.VideoWriter(video_file, VideoWriter_fourcc(*fourcc), fps,
+                              resolution)
+
+    def write_frame(file_idx):
+        filename = osp.join(frame_dir, filename_tmpl.format(file_idx))
+        img = cv2.imread(filename)
+        vwriter.write(img)
+
+    if show_progress:
+        track_progress(write_frame, range(start, end))
+    else:
+        for i in range(start, end):
+            write_frame(i)
+    vwriter.release()
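
``cvt2frames`` and ``frames2video`` are roughly inverse operations; a round-trip sketch (the file and directory names are placeholders):

>>> from annotator.uniformer.mmcv.video import VideoReader, frames2video
>>> video = VideoReader('sample.mp4')
>>> video.cvt2frames('./frames')  # writes ./frames/000000.jpg, 000001.jpg, ...
>>> frames2video('./frames', 'out.avi', fps=video.fps)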
diff --git a/annotator/uniformer/mmcv/video/optflow.py b/annotator/uniformer/mmcv/video/optflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..84160f8d6ef9fceb5a2f89e7481593109fc1905d
--- /dev/null
+++ b/annotator/uniformer/mmcv/video/optflow.py
@@ -0,0 +1,254 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import cv2
+import numpy as np
+
+from annotator.uniformer.mmcv.arraymisc import dequantize, quantize
+from annotator.uniformer.mmcv.image import imread, imwrite
+from annotator.uniformer.mmcv.utils import is_str
+
+
+def flowread(flow_or_path, quantize=False, concat_axis=0, *args, **kwargs):
+    """Read an optical flow map.
+
+    Args:
+        flow_or_path (ndarray or str): A flow map or filepath.
+        quantize (bool): Whether to read a quantized pair; if set to True,
+            remaining args will be passed to :func:`dequantize_flow`.
+        concat_axis (int): The axis along which dx and dy are concatenated,
+            can be either 0 or 1. Ignored if quantize is False.
+
+    Returns:
+        ndarray: Optical flow represented as a (h, w, 2) numpy array
+    """
+    if isinstance(flow_or_path, np.ndarray):
+        if (flow_or_path.ndim != 3) or (flow_or_path.shape[-1] != 2):
+            raise ValueError(f'Invalid flow with shape {flow_or_path.shape}')
+        return flow_or_path
+    elif not is_str(flow_or_path):
+        raise TypeError(f'"flow_or_path" must be a filename or numpy array, '
+                        f'not {type(flow_or_path)}')
+
+    if not quantize:
+        with open(flow_or_path, 'rb') as f:
+            try:
+                header = f.read(4).decode('utf-8')
+            except Exception:
+                raise IOError(f'Invalid flow file: {flow_or_path}')
+            else:
+                if header != 'PIEH':
+                    raise IOError(f'Invalid flow file: {flow_or_path}, '
+                                  'header does not contain PIEH')
+
+            w = np.fromfile(f, np.int32, 1).squeeze()
+            h = np.fromfile(f, np.int32, 1).squeeze()
+            flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
+    else:
+        assert concat_axis in [0, 1]
+        cat_flow = imread(flow_or_path, flag='unchanged')
+        if cat_flow.ndim != 2:
+            raise IOError(
+                f'{flow_or_path} is not a valid quantized flow file, '
+                f'its dimension is {cat_flow.ndim}.')
+        assert cat_flow.shape[concat_axis] % 2 == 0
+        dx, dy = np.split(cat_flow, 2, axis=concat_axis)
+        flow = dequantize_flow(dx, dy, *args, **kwargs)
+
+    return flow.astype(np.float32)
+
+
+def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
+    """Write optical flow to file.
+
+    If the flow is not quantized, it will be saved as a .flo file losslessly;
+    otherwise it is saved as a jpeg image, which is lossy but much smaller.
+    (dx and dy are concatenated along ``concat_axis`` into a single image
+    when quantize is True.)
+
+    Args:
+        flow (ndarray): (h, w, 2) array of optical flow.
+        filename (str): Output filepath.
+        quantize (bool): Whether to quantize the flow and save it as a jpeg
+            image. If set to True, remaining args will be passed to
+            :func:`quantize_flow`.
+        concat_axis (int): The axis along which dx and dy are concatenated,
+            can be either 0 or 1. Ignored if quantize is False.
+    """
+    if not quantize:
+        with open(filename, 'wb') as f:
+            f.write('PIEH'.encode('utf-8'))
+            np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
+            flow = flow.astype(np.float32)
+            flow.tofile(f)
+            f.flush()
+    else:
+        assert concat_axis in [0, 1]
+        dx, dy = quantize_flow(flow, *args, **kwargs)
+        dxdy = np.concatenate((dx, dy), axis=concat_axis)
+        imwrite(dxdy, filename)
+
+
+def quantize_flow(flow, max_val=0.02, norm=True):
+    """Quantize flow to [0, 255].
+
+    After this step, the size of flow will be much smaller, and can be
+    dumped as jpeg images.
+
+    Args:
+        flow (ndarray): (h, w, 2) array of optical flow.
+        max_val (float): Maximum value of flow; values beyond
+            [-max_val, max_val] will be truncated.
+        norm (bool): Whether to divide flow values by image width/height.
+
+    Returns:
+        tuple[ndarray]: Quantized dx and dy.
+    """
+    h, w, _ = flow.shape
+    dx = flow[..., 0]
+    dy = flow[..., 1]
+    if norm:
+        dx = dx / w  # avoid inplace operations
+        dy = dy / h
+    # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
+    flow_comps = [
+        quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]
+    ]
+    return tuple(flow_comps)
+
+
+def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
+    """Recover from quantized flow.
+
+    Args:
+        dx (ndarray): Quantized dx.
+        dy (ndarray): Quantized dy.
+        max_val (float): Maximum value used when quantizing.
+        denorm (bool): Whether to multiply flow values with width/height.
+
+    Returns:
+        ndarray: Dequantized flow.
+    """
+    assert dx.shape == dy.shape
+    assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
+
+    dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
+
+    if denorm:
+        dx *= dx.shape[1]
+        dy *= dx.shape[0]
+    flow = np.dstack((dx, dy))
+    return flow
+
+
+def flow_warp(img, flow, filling_value=0, interpolate_mode='nearest'):
+    """Use flow to warp img.
+
+    Args:
+        img (ndarray, float or uint8): Image to be warped.
+        flow (ndarray, float): Optical Flow.
+        filling_value (int): The missing pixels will be set with filling_value.
+        interpolate_mode (str): bilinear -> Bilinear Interpolation;
+                                nearest -> Nearest Neighbor.
+
+    Returns:
+        ndarray: Warped image with the same shape as img.
+    """
+    warnings.warn('This function is just for prototyping and cannot '
+                  'guarantee the computational efficiency.')
+    assert flow.ndim == 3, 'Flow must be in 3D arrays.'
+    height = flow.shape[0]
+    width = flow.shape[1]
+    channels = img.shape[2]
+
+    output = np.ones(
+        (height, width, channels), dtype=img.dtype) * filling_value
+
+    grid = np.indices((height, width)).swapaxes(0, 1).swapaxes(1, 2)
+    dx = grid[:, :, 0] + flow[:, :, 1]
+    dy = grid[:, :, 1] + flow[:, :, 0]
+    sx = np.floor(dx).astype(int)
+    sy = np.floor(dy).astype(int)
+    valid = (sx >= 0) & (sx < height - 1) & (sy >= 0) & (sy < width - 1)
+
+    if interpolate_mode == 'nearest':
+        output[valid, :] = img[dx[valid].round().astype(int),
+                               dy[valid].round().astype(int), :]
+    elif interpolate_mode == 'bilinear':
+        # dirty workaround for integer positions
+        eps_ = 1e-6
+        dx, dy = dx + eps_, dy + eps_
+        left_top_ = img[np.floor(dx[valid]).astype(int),
+                        np.floor(dy[valid]).astype(int), :] * (
+                            np.ceil(dx[valid]) - dx[valid])[:, None] * (
+                                np.ceil(dy[valid]) - dy[valid])[:, None]
+        left_down_ = img[np.ceil(dx[valid]).astype(int),
+                         np.floor(dy[valid]).astype(int), :] * (
+                             dx[valid] - np.floor(dx[valid]))[:, None] * (
+                                 np.ceil(dy[valid]) - dy[valid])[:, None]
+        right_top_ = img[np.floor(dx[valid]).astype(int),
+                         np.ceil(dy[valid]).astype(int), :] * (
+                             np.ceil(dx[valid]) - dx[valid])[:, None] * (
+                                 dy[valid] - np.floor(dy[valid]))[:, None]
+        right_down_ = img[np.ceil(dx[valid]).astype(int),
+                          np.ceil(dy[valid]).astype(int), :] * (
+                              dx[valid] - np.floor(dx[valid]))[:, None] * (
+                                  dy[valid] - np.floor(dy[valid]))[:, None]
+        output[valid, :] = left_top_ + left_down_ + right_top_ + right_down_
+    else:
+        raise NotImplementedError(
+            'We only support interpolation modes of nearest and bilinear, '
+            f'but got {interpolate_mode}.')
+    return output.astype(img.dtype)
+
+
+def flow_from_bytes(content):
+    """Read dense optical flow from bytes.
+
+    .. note::
+        This optical flow loader works for the FlyingChairs, FlyingThings3D,
+        Sintel and FlyingChairsOcc datasets, but cannot load the data from
+        ChairsSDHom.
+
+    Args:
+        content (bytes): Optical flow bytes got from files or other streams.
+
+    Returns:
+        ndarray: Loaded optical flow with the shape (H, W, 2).
+    """
+
+    # header in first 4 bytes
+    header = content[:4]
+    if header.decode('utf-8') != 'PIEH':
+        raise Exception('Flow file header does not contain PIEH')
+    # width in second 4 bytes
+    width = np.frombuffer(content[4:], np.int32, 1).squeeze()
+    # height in third 4 bytes
+    height = np.frombuffer(content[8:], np.int32, 1).squeeze()
+    # after first 12 bytes, all bytes are flow
+    flow = np.frombuffer(content[12:], np.float32, width * height * 2).reshape(
+        (height, width, 2))
+
+    return flow
+
+
+def sparse_flow_from_bytes(content):
+    """Read the optical flow in KITTI datasets from bytes.
+
+    This function is modified from RAFT's loader for the `KITTI datasets
+    <https://github.com/princeton-vl/RAFT/blob/224320502d66c356d88e6c712f38129e60661e80/core/utils/frame_utils.py#L102>`_.
+
+    Args:
+        content (bytes): Optical flow bytes got from files or other streams.
+
+    Returns:
+        Tuple(ndarray, ndarray): Loaded optical flow with the shape (H, W, 2)
+            and flow valid mask with the shape (H, W).
+    """  # nopa
+
+    content = np.frombuffer(content, np.uint8)
+    flow = cv2.imdecode(content, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
+    flow = flow[:, :, ::-1].astype(np.float32)
+    # flow shape (H, W, 2) valid shape (H, W)
+    flow, valid = flow[:, :, :2], flow[:, :, 2]
+    flow = (flow - 2**15) / 64.0
+    return flow, valid
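
A small round-trip sketch for the quantization helpers above (random flow; the values are illustrative only):

>>> import numpy as np
>>> flow = (np.random.rand(4, 6, 2).astype(np.float32) - 0.5) * 0.02
>>> dx, dy = quantize_flow(flow, max_val=0.02, norm=False)
>>> dx.dtype, dx.shape
(dtype('uint8'), (4, 6))
>>> dequantize_flow(dx, dy, max_val=0.02, denorm=False).shape
(4, 6, 2)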
diff --git a/annotator/uniformer/mmcv/video/processing.py b/annotator/uniformer/mmcv/video/processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d90b96e0823d5f116755e7f498d25d17017224a
--- /dev/null
+++ b/annotator/uniformer/mmcv/video/processing.py
@@ -0,0 +1,160 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import os.path as osp
+import subprocess
+import tempfile
+
+from annotator.uniformer.mmcv.utils import requires_executable
+
+
+@requires_executable('ffmpeg')
+def convert_video(in_file,
+                  out_file,
+                  print_cmd=False,
+                  pre_options='',
+                  **kwargs):
+    """Convert a video with ffmpeg.
+
+    This provides a general api to ffmpeg; the executed command is::
+
+        `ffmpeg -y <pre_options> -i <in_file> <options> <out_file>`
+
+    Options(kwargs) are mapped to ffmpeg commands with the following rules:
+
+    - key=val: "-key val"
+    - key=True: "-key"
+    - key=False: ""
+
+    Args:
+        in_file (str): Input video filename.
+        out_file (str): Output video filename.
+        pre_options (str): Options that appear before "-i <in_file>".
+        print_cmd (bool): Whether to print the final ffmpeg command.
+    """
+    options = []
+    for k, v in kwargs.items():
+        if isinstance(v, bool):
+            if v:
+                options.append(f'-{k}')
+        elif k == 'log_level':
+            assert v in [
+                'quiet', 'panic', 'fatal', 'error', 'warning', 'info',
+                'verbose', 'debug', 'trace'
+            ]
+            options.append(f'-loglevel {v}')
+        else:
+            options.append(f'-{k} {v}')
+    cmd = f'ffmpeg -y {pre_options} -i {in_file} {" ".join(options)} ' \
+          f'{out_file}'
+    if print_cmd:
+        print(cmd)
+    subprocess.call(cmd, shell=True)
+
+
+@requires_executable('ffmpeg')
+def resize_video(in_file,
+                 out_file,
+                 size=None,
+                 ratio=None,
+                 keep_ar=False,
+                 log_level='info',
+                 print_cmd=False):
+    """Resize a video.
+
+    Args:
+        in_file (str): Input video filename.
+        out_file (str): Output video filename.
+        size (tuple): Expected size (w, h), e.g., (320, 240) or (320, -1).
+        ratio (tuple or float): Expected resize ratio, (2, 0.5) means
+            (w*2, h*0.5).
+        keep_ar (bool): Whether to keep original aspect ratio.
+        log_level (str): Logging level of ffmpeg.
+        print_cmd (bool): Whether to print the final ffmpeg command.
+    """
+    if size is None and ratio is None:
+        raise ValueError('expected size or ratio must be specified')
+    if size is not None and ratio is not None:
+        raise ValueError('size and ratio cannot be specified at the same time')
+    options = {'log_level': log_level}
+    if size:
+        if not keep_ar:
+            options['vf'] = f'scale={size[0]}:{size[1]}'
+        else:
+            options['vf'] = f'scale=w={size[0]}:h={size[1]}:' \
+                            'force_original_aspect_ratio=decrease'
+    else:
+        if not isinstance(ratio, tuple):
+            ratio = (ratio, ratio)
+        options['vf'] = f'scale="trunc(iw*{ratio[0]}):trunc(ih*{ratio[1]})"'
+    convert_video(in_file, out_file, print_cmd, **options)
+
+
+@requires_executable('ffmpeg')
+def cut_video(in_file,
+              out_file,
+              start=None,
+              end=None,
+              vcodec=None,
+              acodec=None,
+              log_level='info',
+              print_cmd=False):
+    """Cut a clip from a video.
+
+    Args:
+        in_file (str): Input video filename.
+        out_file (str): Output video filename.
+        start (None or float): Start time (in seconds).
+        end (None or float): End time (in seconds).
+        vcodec (None or str): Output video codec, None for unchanged.
+        acodec (None or str): Output audio codec, None for unchanged.
+        log_level (str): Logging level of ffmpeg.
+        print_cmd (bool): Whether to print the final ffmpeg command.
+    """
+    options = {'log_level': log_level}
+    if vcodec is None:
+        options['vcodec'] = 'copy'
+    if acodec is None:
+        options['acodec'] = 'copy'
+    if start:
+        options['ss'] = start
+    else:
+        start = 0
+    if end:
+        options['t'] = end - start
+    convert_video(in_file, out_file, print_cmd, **options)
+
+
+@requires_executable('ffmpeg')
+def concat_video(video_list,
+                 out_file,
+                 vcodec=None,
+                 acodec=None,
+                 log_level='info',
+                 print_cmd=False):
+    """Concatenate multiple videos into a single one.
+
+    Args:
+        video_list (list): A list of video filenames
+        out_file (str): Output video filename
+        vcodec (None or str): Output video codec, None for unchanged
+        acodec (None or str): Output audio codec, None for unchanged
+        log_level (str): Logging level of ffmpeg.
+        print_cmd (bool): Whether to print the final ffmpeg command.
+    """
+    tmp_filehandler, tmp_filename = tempfile.mkstemp(suffix='.txt', text=True)
+    with open(tmp_filename, 'w') as f:
+        for filename in video_list:
+            f.write(f'file {osp.abspath(filename)}\n')
+    options = {'log_level': log_level}
+    if vcodec is None:
+        options['vcodec'] = 'copy'
+    if acodec is None:
+        options['acodec'] = 'copy'
+    convert_video(
+        tmp_filename,
+        out_file,
+        print_cmd,
+        pre_options='-f concat -safe 0',
+        **options)
+    os.close(tmp_filehandler)
+    os.remove(tmp_filename)
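
Typical usage of the ffmpeg wrappers above (a sketch only; it assumes ``ffmpeg`` is on ``PATH`` and the file names are placeholders):

>>> from annotator.uniformer.mmcv.video import resize_video, cut_video, concat_video
>>> resize_video('in.mp4', 'small.mp4', size=(360, -1), log_level='error')
>>> cut_video('in.mp4', 'clip.mp4', start=2, end=10)
>>> concat_video(['clip_a.mp4', 'clip_b.mp4'], 'joined.mp4')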
diff --git a/annotator/uniformer/mmcv/visualization/__init__.py b/annotator/uniformer/mmcv/visualization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..835df136bdcf69348281d22914d41aa84cdf92b1
--- /dev/null
+++ b/annotator/uniformer/mmcv/visualization/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .color import Color, color_val
+from .image import imshow, imshow_bboxes, imshow_det_bboxes
+from .optflow import flow2rgb, flowshow, make_color_wheel
+
+__all__ = [
+    'Color', 'color_val', 'imshow', 'imshow_bboxes', 'imshow_det_bboxes',
+    'flowshow', 'flow2rgb', 'make_color_wheel'
+]
diff --git a/annotator/uniformer/mmcv/visualization/color.py b/annotator/uniformer/mmcv/visualization/color.py
new file mode 100644
index 0000000000000000000000000000000000000000..9041e0e6b7581c3356795d6a3c5e84667c88f025
--- /dev/null
+++ b/annotator/uniformer/mmcv/visualization/color.py
@@ -0,0 +1,51 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from enum import Enum
+
+import numpy as np
+
+from annotator.uniformer.mmcv.utils import is_str
+
+
+class Color(Enum):
+    """An enum that defines common colors.
+
+    Contains red, green, blue, cyan, yellow, magenta, white and black.
+    """
+    red = (0, 0, 255)
+    green = (0, 255, 0)
+    blue = (255, 0, 0)
+    cyan = (255, 255, 0)
+    yellow = (0, 255, 255)
+    magenta = (255, 0, 255)
+    white = (255, 255, 255)
+    black = (0, 0, 0)
+
+
+def color_val(color):
+    """Convert various input to color tuples.
+
+    Args:
+        color (:obj:`Color`/str/tuple/int/ndarray): Color inputs
+
+    Returns:
+        tuple[int]: A tuple of 3 integers indicating BGR channels.
+    """
+    if is_str(color):
+        return Color[color].value
+    elif isinstance(color, Color):
+        return color.value
+    elif isinstance(color, tuple):
+        assert len(color) == 3
+        for channel in color:
+            assert 0 <= channel <= 255
+        return color
+    elif isinstance(color, int):
+        assert 0 <= color <= 255
+        return color, color, color
+    elif isinstance(color, np.ndarray):
+        assert color.ndim == 1 and color.size == 3
+        assert np.all((color >= 0) & (color <= 255))
+        color = color.astype(np.uint8)
+        return tuple(color)
+    else:
+        raise TypeError(f'Invalid type for color: {type(color)}')
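
``color_val`` normalizes several color notations to a BGR tuple; a few illustrative conversions:

>>> color_val('green')
(0, 255, 0)
>>> color_val(Color.red)
(0, 0, 255)
>>> color_val(128)
(128, 128, 128)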
diff --git a/annotator/uniformer/mmcv/visualization/image.py b/annotator/uniformer/mmcv/visualization/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..61a56c75b67f593c298408462c63c0468be8e276
--- /dev/null
+++ b/annotator/uniformer/mmcv/visualization/image.py
@@ -0,0 +1,152 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import cv2
+import numpy as np
+
+from annotator.uniformer.mmcv.image import imread, imwrite
+from .color import color_val
+
+
+def imshow(img, win_name='', wait_time=0):
+    """Show an image.
+
+    Args:
+        img (str or ndarray): The image to be displayed.
+        win_name (str): The window name.
+        wait_time (int): Value of waitKey param.
+    """
+    cv2.imshow(win_name, imread(img))
+    if wait_time == 0:  # prevent hanging if the window was closed
+        while True:
+            ret = cv2.waitKey(1)
+
+            closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1
+            # if user closed window or if some key pressed
+            if closed or ret != -1:
+                break
+    else:
+        ret = cv2.waitKey(wait_time)
+
+
+def imshow_bboxes(img,
+                  bboxes,
+                  colors='green',
+                  top_k=-1,
+                  thickness=1,
+                  show=True,
+                  win_name='',
+                  wait_time=0,
+                  out_file=None):
+    """Draw bboxes on an image.
+
+    Args:
+        img (str or ndarray): The image to be displayed.
+        bboxes (list or ndarray): A list of ndarray of shape (k, 4).
+        colors (list[str or tuple or Color]): A list of colors.
+        top_k (int): Plot the first k bboxes only if set positive.
+        thickness (int): Thickness of lines.
+        show (bool): Whether to show the image.
+        win_name (str): The window name.
+        wait_time (int): Value of waitKey param.
+        out_file (str, optional): The filename to write the image.
+
+    Returns:
+        ndarray: The image with bboxes drawn on it.
+    """
+    img = imread(img)
+    img = np.ascontiguousarray(img)
+
+    if isinstance(bboxes, np.ndarray):
+        bboxes = [bboxes]
+    if not isinstance(colors, list):
+        colors = [colors for _ in range(len(bboxes))]
+    colors = [color_val(c) for c in colors]
+    assert len(bboxes) == len(colors)
+
+    for i, _bboxes in enumerate(bboxes):
+        _bboxes = _bboxes.astype(np.int32)
+        if top_k <= 0:
+            _top_k = _bboxes.shape[0]
+        else:
+            _top_k = min(top_k, _bboxes.shape[0])
+        for j in range(_top_k):
+            left_top = (_bboxes[j, 0], _bboxes[j, 1])
+            right_bottom = (_bboxes[j, 2], _bboxes[j, 3])
+            cv2.rectangle(
+                img, left_top, right_bottom, colors[i], thickness=thickness)
+
+    if show:
+        imshow(img, win_name, wait_time)
+    if out_file is not None:
+        imwrite(img, out_file)
+    return img
+
+
+def imshow_det_bboxes(img,
+                      bboxes,
+                      labels,
+                      class_names=None,
+                      score_thr=0,
+                      bbox_color='green',
+                      text_color='green',
+                      thickness=1,
+                      font_scale=0.5,
+                      show=True,
+                      win_name='',
+                      wait_time=0,
+                      out_file=None):
+    """Draw bboxes and class labels (with scores) on an image.
+
+    Args:
+        img (str or ndarray): The image to be displayed.
+        bboxes (ndarray): Bounding boxes (with scores), shaped (n, 4) or
+            (n, 5).
+        labels (ndarray): Labels of bboxes.
+        class_names (list[str]): Name of each class.
+        score_thr (float): Minimum score of bboxes to be shown.
+        bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
+        text_color (str or tuple or :obj:`Color`): Color of texts.
+        thickness (int): Thickness of lines.
+        font_scale (float): Font scales of texts.
+        show (bool): Whether to show the image.
+        win_name (str): The window name.
+        wait_time (int): Value of waitKey param.
+        out_file (str or None): The filename to write the image.
+
+    Returns:
+        ndarray: The image with bboxes drawn on it.
+    """
+    assert bboxes.ndim == 2
+    assert labels.ndim == 1
+    assert bboxes.shape[0] == labels.shape[0]
+    assert bboxes.shape[1] == 4 or bboxes.shape[1] == 5
+    img = imread(img)
+    img = np.ascontiguousarray(img)
+
+    if score_thr > 0:
+        assert bboxes.shape[1] == 5
+        scores = bboxes[:, -1]
+        inds = scores > score_thr
+        bboxes = bboxes[inds, :]
+        labels = labels[inds]
+
+    bbox_color = color_val(bbox_color)
+    text_color = color_val(text_color)
+
+    for bbox, label in zip(bboxes, labels):
+        bbox_int = bbox.astype(np.int32)
+        left_top = (bbox_int[0], bbox_int[1])
+        right_bottom = (bbox_int[2], bbox_int[3])
+        cv2.rectangle(
+            img, left_top, right_bottom, bbox_color, thickness=thickness)
+        label_text = class_names[
+            label] if class_names is not None else f'cls {label}'
+        if len(bbox) > 4:
+            label_text += f'|{bbox[-1]:.02f}'
+        cv2.putText(img, label_text, (bbox_int[0], bbox_int[1] - 2),
+                    cv2.FONT_HERSHEY_COMPLEX, font_scale, text_color)
+
+    if show:
+        imshow(img, win_name, wait_time)
+    if out_file is not None:
+        imwrite(img, out_file)
+    return img
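
A minimal sketch of drawing detection boxes on a blank image (the boxes, scores and output path are made up; ``show=False`` avoids opening a window):

>>> import numpy as np
>>> img = np.zeros((240, 320, 3), dtype=np.uint8)
>>> bboxes = np.array([[20, 30, 120, 150, 0.9]], dtype=np.float32)
>>> labels = np.array([0])
>>> out = imshow_det_bboxes(img, bboxes, labels, class_names=['cat'], score_thr=0.3, show=False, out_file='det.jpg')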
diff --git a/annotator/uniformer/mmcv/visualization/optflow.py b/annotator/uniformer/mmcv/visualization/optflow.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3870c700f7c946177ee5d536ce3f6c814a77ce7
--- /dev/null
+++ b/annotator/uniformer/mmcv/visualization/optflow.py
@@ -0,0 +1,112 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from __future__ import division
+
+import numpy as np
+
+from annotator.uniformer.mmcv.image import rgb2bgr
+from annotator.uniformer.mmcv.video import flowread
+from .image import imshow
+
+
+def flowshow(flow, win_name='', wait_time=0):
+    """Show optical flow.
+
+    Args:
+        flow (ndarray or str): The optical flow to be displayed.
+        win_name (str): The window name.
+        wait_time (int): Value of waitKey param.
+    """
+    flow = flowread(flow)
+    flow_img = flow2rgb(flow)
+    imshow(rgb2bgr(flow_img), win_name, wait_time)
+
+
+def flow2rgb(flow, color_wheel=None, unknown_thr=1e6):
+    """Convert flow map to RGB image.
+
+    Args:
+        flow (ndarray): Array of optical flow.
+        color_wheel (ndarray or None): Color wheel used to map flow field to
+            RGB colorspace. Default color wheel will be used if not specified.
+        unknown_thr (float): Values above this threshold will be marked as
+            unknown and thus ignored.
+
+    Returns:
+        ndarray: RGB image that can be visualized.
+    """
+    assert flow.ndim == 3 and flow.shape[-1] == 2
+    if color_wheel is None:
+        color_wheel = make_color_wheel()
+    assert color_wheel.ndim == 2 and color_wheel.shape[1] == 3
+    num_bins = color_wheel.shape[0]
+
+    dx = flow[:, :, 0].copy()
+    dy = flow[:, :, 1].copy()
+
+    ignore_inds = (
+        np.isnan(dx) | np.isnan(dy) | (np.abs(dx) > unknown_thr) |
+        (np.abs(dy) > unknown_thr))
+    dx[ignore_inds] = 0
+    dy[ignore_inds] = 0
+
+    rad = np.sqrt(dx**2 + dy**2)
+    if np.any(rad > np.finfo(float).eps):
+        max_rad = np.max(rad)
+        dx /= max_rad
+        dy /= max_rad
+
+    rad = np.sqrt(dx**2 + dy**2)
+    angle = np.arctan2(-dy, -dx) / np.pi
+
+    bin_real = (angle + 1) / 2 * (num_bins - 1)
+    bin_left = np.floor(bin_real).astype(int)
+    bin_right = (bin_left + 1) % num_bins
+    w = (bin_real - bin_left.astype(np.float32))[..., None]
+    flow_img = (1 -
+                w) * color_wheel[bin_left, :] + w * color_wheel[bin_right, :]
+    small_ind = rad <= 1
+    flow_img[small_ind] = 1 - rad[small_ind, None] * (1 - flow_img[small_ind])
+    flow_img[np.logical_not(small_ind)] *= 0.75
+
+    flow_img[ignore_inds, :] = 0
+
+    return flow_img
+
+
+def make_color_wheel(bins=None):
+    """Build a color wheel.
+
+    Args:
+        bins(list or tuple, optional): Specify the number of bins for each
+            color range, corresponding to six ranges: red -> yellow,
+            yellow -> green, green -> cyan, cyan -> blue, blue -> magenta,
+            magenta -> red. [15, 6, 4, 11, 13, 6] is used for default
+            (see Middlebury).
+
+    Returns:
+        ndarray: Color wheel of shape (total_bins, 3).
+    """
+    if bins is None:
+        bins = [15, 6, 4, 11, 13, 6]
+    assert len(bins) == 6
+
+    RY, YG, GC, CB, BM, MR = tuple(bins)
+
+    ry = [1, np.arange(RY) / RY, 0]
+    yg = [1 - np.arange(YG) / YG, 1, 0]
+    gc = [0, 1, np.arange(GC) / GC]
+    cb = [0, 1 - np.arange(CB) / CB, 1]
+    bm = [np.arange(BM) / BM, 0, 1]
+    mr = [1, 0, 1 - np.arange(MR) / MR]
+
+    num_bins = RY + YG + GC + CB + BM + MR
+
+    color_wheel = np.zeros((3, num_bins), dtype=np.float32)
+
+    col = 0
+    for i, color in enumerate([ry, yg, gc, cb, bm, mr]):
+        for j in range(3):
+            color_wheel[j, col:col + bins[i]] = color[j]
+        col += bins[i]
+
+    return color_wheel.T
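
A short sketch of mapping a random flow field to an RGB image with the helpers above (shapes are illustrative only):

>>> import numpy as np
>>> flow = np.random.randn(48, 64, 2).astype(np.float32)
>>> flow2rgb(flow).shape
(48, 64, 3)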
diff --git a/annotator/uniformer/mmcv_custom/__init__.py b/annotator/uniformer/mmcv_custom/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b958738b9fd93bfcec239c550df1d9a44b8c536
--- /dev/null
+++ b/annotator/uniformer/mmcv_custom/__init__.py
@@ -0,0 +1,5 @@
+# -*- coding: utf-8 -*-
+
+from .checkpoint import load_checkpoint
+
+__all__ = ['load_checkpoint']
\ No newline at end of file
diff --git a/annotator/uniformer/mmcv_custom/checkpoint.py b/annotator/uniformer/mmcv_custom/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..19b87fef0a52d31babcdb3edb8f3089b6420173f
--- /dev/null
+++ b/annotator/uniformer/mmcv_custom/checkpoint.py
@@ -0,0 +1,500 @@
+# Copyright (c) Open-MMLab. All rights reserved.
+import io
+import os
+import os.path as osp
+import pkgutil
+import time
+import warnings
+from collections import OrderedDict
+from importlib import import_module
+from tempfile import TemporaryDirectory
+
+import torch
+import torchvision
+from torch.optim import Optimizer
+from torch.utils import model_zoo
+from torch.nn import functional as F
+
+import annotator.uniformer.mmcv as mmcv
+from annotator.uniformer.mmcv.fileio import FileClient
+from annotator.uniformer.mmcv.fileio import load as load_file
+from annotator.uniformer.mmcv.parallel import is_module_wrapper
+from annotator.uniformer.mmcv.utils import mkdir_or_exist
+from annotator.uniformer.mmcv.runner import get_dist_info
+
+ENV_MMCV_HOME = 'MMCV_HOME'
+ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+DEFAULT_CACHE_DIR = '~/.cache'
+
+
+def _get_mmcv_home():
+    mmcv_home = os.path.expanduser(
+        os.getenv(
+            ENV_MMCV_HOME,
+            os.path.join(
+                os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))
+
+    mkdir_or_exist(mmcv_home)
+    return mmcv_home
+
+
+def load_state_dict(module, state_dict, strict=False, logger=None):
+    """Load state_dict to a module.
+
+    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
+    Default value for ``strict`` is set to ``False`` and the message for
+    param mismatch will be shown even if strict is False.
+
+    Args:
+        module (Module): Module that receives the state_dict.
+        state_dict (OrderedDict): Weights.
+        strict (bool): whether to strictly enforce that the keys
+            in :attr:`state_dict` match the keys returned by this module's
+            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
+        logger (:obj:`logging.Logger`, optional): Logger to log the error
+            message. If not specified, print function will be used.
+    """
+    unexpected_keys = []
+    all_missing_keys = []
+    err_msg = []
+
+    metadata = getattr(state_dict, '_metadata', None)
+    state_dict = state_dict.copy()
+    if metadata is not None:
+        state_dict._metadata = metadata
+
+    # use _load_from_state_dict to enable checkpoint version control
+    def load(module, prefix=''):
+        # recursively check parallel module in case that the model has a
+        # complicated structure, e.g., nn.Module(nn.Module(DDP))
+        if is_module_wrapper(module):
+            module = module.module
+        local_metadata = {} if metadata is None else metadata.get(
+            prefix[:-1], {})
+        module._load_from_state_dict(state_dict, prefix, local_metadata, True,
+                                     all_missing_keys, unexpected_keys,
+                                     err_msg)
+        for name, child in module._modules.items():
+            if child is not None:
+                load(child, prefix + name + '.')
+
+    load(module)
+    load = None  # break load->load reference cycle
+
+    # ignore "num_batches_tracked" of BN layers
+    missing_keys = [
+        key for key in all_missing_keys if 'num_batches_tracked' not in key
+    ]
+
+    if unexpected_keys:
+        err_msg.append('unexpected key in source '
+                       f'state_dict: {", ".join(unexpected_keys)}\n')
+    if missing_keys:
+        err_msg.append(
+            f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
+
+    rank, _ = get_dist_info()
+    if len(err_msg) > 0 and rank == 0:
+        err_msg.insert(
+            0, 'The model and loaded state dict do not match exactly\n')
+        err_msg = '\n'.join(err_msg)
+        if strict:
+            raise RuntimeError(err_msg)
+        elif logger is not None:
+            logger.warning(err_msg)
+        else:
+            print(err_msg)
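+
+
+# Illustrative usage sketch (not part of the upstream file): load a plain
+# state dict non-strictly, reporting key mismatches through a logger instead
+# of raising.
+#
+#     import logging
+#     import torchvision
+#     model = torchvision.models.resnet18()
+#     sd = torchvision.models.resnet18().state_dict()
+#     load_state_dict(model, sd, strict=False,
+#                     logger=logging.getLogger(__name__))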
+
+
+def load_url_dist(url, model_dir=None):
+    """In distributed setting, this function only download checkpoint at local
+    rank 0."""
+    rank, world_size = get_dist_info()
+    rank = int(os.environ.get('LOCAL_RANK', rank))
+    if rank == 0:
+        checkpoint = model_zoo.load_url(url, model_dir=model_dir)
+    if world_size > 1:
+        torch.distributed.barrier()
+        if rank > 0:
+            checkpoint = model_zoo.load_url(url, model_dir=model_dir)
+    return checkpoint
+
+
+def load_pavimodel_dist(model_path, map_location=None):
+    """In distributed setting, this function only download checkpoint at local
+    rank 0."""
+    try:
+        from pavi import modelcloud
+    except ImportError:
+        raise ImportError(
+            'Please install pavi to load checkpoint from modelcloud.')
+    rank, world_size = get_dist_info()
+    rank = int(os.environ.get('LOCAL_RANK', rank))
+    if rank == 0:
+        model = modelcloud.get(model_path)
+        with TemporaryDirectory() as tmp_dir:
+            downloaded_file = osp.join(tmp_dir, model.name)
+            model.download(downloaded_file)
+            checkpoint = torch.load(downloaded_file, map_location=map_location)
+    if world_size > 1:
+        torch.distributed.barrier()
+        if rank > 0:
+            model = modelcloud.get(model_path)
+            with TemporaryDirectory() as tmp_dir:
+                downloaded_file = osp.join(tmp_dir, model.name)
+                model.download(downloaded_file)
+                checkpoint = torch.load(
+                    downloaded_file, map_location=map_location)
+    return checkpoint
+
+
+def load_fileclient_dist(filename, backend, map_location):
+    """In distributed setting, this function only download checkpoint at local
+    rank 0."""
+    rank, world_size = get_dist_info()
+    rank = int(os.environ.get('LOCAL_RANK', rank))
+    allowed_backends = ['ceph']
+    if backend not in allowed_backends:
+        raise ValueError(f'Load from Backend {backend} is not supported.')
+    if rank == 0:
+        fileclient = FileClient(backend=backend)
+        buffer = io.BytesIO(fileclient.get(filename))
+        checkpoint = torch.load(buffer, map_location=map_location)
+    if world_size > 1:
+        torch.distributed.barrier()
+        if rank > 0:
+            fileclient = FileClient(backend=backend)
+            buffer = io.BytesIO(fileclient.get(filename))
+            checkpoint = torch.load(buffer, map_location=map_location)
+    return checkpoint
+
+
+def get_torchvision_models():
+    model_urls = dict()
+    for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
+        if ispkg:
+            continue
+        _zoo = import_module(f'torchvision.models.{name}')
+        if hasattr(_zoo, 'model_urls'):
+            _urls = getattr(_zoo, 'model_urls')
+            model_urls.update(_urls)
+    return model_urls
+
+
+def get_external_models():
+    mmcv_home = _get_mmcv_home()
+    default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
+    default_urls = load_file(default_json_path)
+    assert isinstance(default_urls, dict)
+    external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
+    if osp.exists(external_json_path):
+        external_urls = load_file(external_json_path)
+        assert isinstance(external_urls, dict)
+        default_urls.update(external_urls)
+
+    return default_urls
+
+
+def get_mmcls_models():
+    mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
+    mmcls_urls = load_file(mmcls_json_path)
+
+    return mmcls_urls
+
+
+def get_deprecated_model_names():
+    deprecate_json_path = osp.join(mmcv.__path__[0],
+                                   'model_zoo/deprecated.json')
+    deprecate_urls = load_file(deprecate_json_path)
+    assert isinstance(deprecate_urls, dict)
+
+    return deprecate_urls
+
+
+def _process_mmcls_checkpoint(checkpoint):
+    state_dict = checkpoint['state_dict']
+    new_state_dict = OrderedDict()
+    for k, v in state_dict.items():
+        if k.startswith('backbone.'):
+            new_state_dict[k[9:]] = v
+    new_checkpoint = dict(state_dict=new_state_dict)
+
+    return new_checkpoint
+
+
+def _load_checkpoint(filename, map_location=None):
+    """Load checkpoint from somewhere (modelzoo, file, url).
+
+    Args:
+        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+            details.
+        map_location (str | None): Same as :func:`torch.load`. Default: None.
+
+    Returns:
+        dict | OrderedDict: The loaded checkpoint. It can be either an
+            OrderedDict storing model weights or a dict containing other
+            information, which depends on the checkpoint.
+    """
+    if filename.startswith('modelzoo://'):
+        warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
+                      'use "torchvision://" instead')
+        model_urls = get_torchvision_models()
+        model_name = filename[11:]
+        checkpoint = load_url_dist(model_urls[model_name])
+    elif filename.startswith('torchvision://'):
+        model_urls = get_torchvision_models()
+        model_name = filename[14:]
+        checkpoint = load_url_dist(model_urls[model_name])
+    elif filename.startswith('open-mmlab://'):
+        model_urls = get_external_models()
+        model_name = filename[13:]
+        deprecated_urls = get_deprecated_model_names()
+        if model_name in deprecated_urls:
+            warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
+                          f'of open-mmlab://{deprecated_urls[model_name]}')
+            model_name = deprecated_urls[model_name]
+        model_url = model_urls[model_name]
+        # check if is url
+        if model_url.startswith(('http://', 'https://')):
+            checkpoint = load_url_dist(model_url)
+        else:
+            filename = osp.join(_get_mmcv_home(), model_url)
+            if not osp.isfile(filename):
+                raise IOError(f'{filename} is not a checkpoint file')
+            checkpoint = torch.load(filename, map_location=map_location)
+    elif filename.startswith('mmcls://'):
+        model_urls = get_mmcls_models()
+        model_name = filename[8:]
+        checkpoint = load_url_dist(model_urls[model_name])
+        checkpoint = _process_mmcls_checkpoint(checkpoint)
+    elif filename.startswith(('http://', 'https://')):
+        checkpoint = load_url_dist(filename)
+    elif filename.startswith('pavi://'):
+        model_path = filename[7:]
+        checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
+    elif filename.startswith('s3://'):
+        checkpoint = load_fileclient_dist(
+            filename, backend='ceph', map_location=map_location)
+    else:
+        if not osp.isfile(filename):
+            raise IOError(f'{filename} is not a checkpoint file')
+        checkpoint = torch.load(filename, map_location=map_location)
+    return checkpoint
+
+
+def load_checkpoint(model,
+                    filename,
+                    map_location='cpu',
+                    strict=False,
+                    logger=None):
+    """Load checkpoint from a file or URI.
+
+    Args:
+        model (Module): Module to load checkpoint.
+        filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+            ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+            details.
+        map_location (str): Same as :func:`torch.load`.
+        strict (bool): Whether to allow different params for the model and
+            checkpoint.
+        logger (:mod:`logging.Logger` or None): The logger for error message.
+
+    Returns:
+        dict or OrderedDict: The loaded checkpoint.
+    """
+    checkpoint = _load_checkpoint(filename, map_location)
+    # OrderedDict is a subclass of dict
+    if not isinstance(checkpoint, dict):
+        raise RuntimeError(
+            f'No state_dict found in checkpoint file {filename}')
+    # get state_dict from checkpoint
+    if 'state_dict' in checkpoint:
+        state_dict = checkpoint['state_dict']
+    elif 'model' in checkpoint:
+        state_dict = checkpoint['model']
+    else:
+        state_dict = checkpoint
+    # strip prefix of state_dict
+    if list(state_dict.keys())[0].startswith('module.'):
+        state_dict = {k[7:]: v for k, v in state_dict.items()}
+
+    # for MoBY, load model of online branch
+    if sorted(list(state_dict.keys()))[0].startswith('encoder'):
+        state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
+
+    # reshape absolute position embedding
+    if state_dict.get('absolute_pos_embed') is not None:
+        absolute_pos_embed = state_dict['absolute_pos_embed']
+        N1, L, C1 = absolute_pos_embed.size()
+        N2, C2, H, W = model.absolute_pos_embed.size()
+        if N1 != N2 or C1 != C2 or L != H * W:
+            # `logger` may be None (the default); fall back to warnings.warn
+            (logger.warning if logger is not None else warnings.warn)(
+                'Error in loading absolute_pos_embed, pass')
+        else:
+            state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
+                N2, H, W, C2).permute(0, 3, 1, 2)
+
+    # interpolate position bias table if needed
+    relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
+    for table_key in relative_position_bias_table_keys:
+        table_pretrained = state_dict[table_key]
+        table_current = model.state_dict()[table_key]
+        L1, nH1 = table_pretrained.size()
+        L2, nH2 = table_current.size()
+        if nH1 != nH2:
+            # `logger` may be None (the default); fall back to warnings.warn
+            (logger.warning if logger is not None else warnings.warn)(
+                f'Error in loading {table_key}, pass')
+        else:
+            if L1 != L2:
+                S1 = int(L1 ** 0.5)
+                S2 = int(L2 ** 0.5)
+                table_pretrained_resized = F.interpolate(
+                     table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
+                     size=(S2, S2), mode='bicubic')
+                state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)
+
+    # load state_dict
+    load_state_dict(model, state_dict, strict, logger)
+    return checkpoint
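+
+
+# Illustrative usage sketch (not part of the upstream file): `filename`
+# accepts local paths, URLs and the prefixes handled by `_load_checkpoint`
+# above. The 'torchvision://resnet50' name below assumes the installed
+# torchvision still exposes `model_urls` for that model.
+#
+#     import torchvision
+#     model = torchvision.models.resnet50()
+#     ckpt = load_checkpoint(model, 'torchvision://resnet50',
+#                            map_location='cpu', strict=False)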
+
+
+def weights_to_cpu(state_dict):
+    """Copy a model state_dict to cpu.
+
+    Args:
+        state_dict (OrderedDict): Model weights on GPU.
+
+    Returns:
+        OrderedDict: Model weights on CPU.
+    """
+    state_dict_cpu = OrderedDict()
+    for key, val in state_dict.items():
+        state_dict_cpu[key] = val.cpu()
+    return state_dict_cpu
+
+
+def _save_to_state_dict(module, destination, prefix, keep_vars):
+    """Saves module state to `destination` dictionary.
+
+    This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
+
+    Args:
+        module (nn.Module): The module to generate state_dict.
+        destination (dict): A dict where state will be stored.
+        prefix (str): The prefix for parameters and buffers used in this
+            module.
+        keep_vars (bool): Whether to keep the parameters and buffers as
+            autograd-tracked tensors instead of detached copies.
+    """
+    for name, param in module._parameters.items():
+        if param is not None:
+            destination[prefix + name] = param if keep_vars else param.detach()
+    for name, buf in module._buffers.items():
+        # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
+        if buf is not None:
+            destination[prefix + name] = buf if keep_vars else buf.detach()
+
+
+def get_state_dict(module, destination=None, prefix='', keep_vars=False):
+    """Returns a dictionary containing a whole state of the module.
+
+    Both parameters and persistent buffers (e.g. running averages) are
+    included. Keys are corresponding parameter and buffer names.
+
+    This method is modified from :meth:`torch.nn.Module.state_dict` to
+    recursively check parallel module in case that the model has a complicated
+    structure, e.g., nn.Module(nn.Module(DDP)).
+
+    Args:
+        module (nn.Module): The module to generate state_dict.
+        destination (OrderedDict): Returned dict for the state of the
+            module.
+        prefix (str): Prefix of the key.
+        keep_vars (bool): Whether to keep the variable property of the
+            parameters. Default: False.
+
+    Returns:
+        dict: A dictionary containing a whole state of the module.
+    """
+    # recursively check parallel module in case that the model has a
+    # complicated structure, e.g., nn.Module(nn.Module(DDP))
+    if is_module_wrapper(module):
+        module = module.module
+
+    # below is the same as torch.nn.Module.state_dict()
+    if destination is None:
+        destination = OrderedDict()
+        destination._metadata = OrderedDict()
+    destination._metadata[prefix[:-1]] = local_metadata = dict(
+        version=module._version)
+    _save_to_state_dict(module, destination, prefix, keep_vars)
+    for name, child in module._modules.items():
+        if child is not None:
+            get_state_dict(
+                child, destination, prefix + name + '.', keep_vars=keep_vars)
+    for hook in module._state_dict_hooks.values():
+        hook_result = hook(module, destination, prefix, local_metadata)
+        if hook_result is not None:
+            destination = hook_result
+    return destination
+
+
+def save_checkpoint(model, filename, optimizer=None, meta=None):
+    """Save checkpoint to file.
+
+    The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
+    ``optimizer``. By default ``meta`` will contain version and time info.
+
+    Args:
+        model (Module): Module whose params are to be saved.
+        filename (str): Checkpoint filename.
+        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
+        meta (dict, optional): Metadata to be saved in checkpoint.
+    """
+    if meta is None:
+        meta = {}
+    elif not isinstance(meta, dict):
+        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
+    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
+
+    if is_module_wrapper(model):
+        model = model.module
+
+    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
+        # save class name to the meta
+        meta.update(CLASSES=model.CLASSES)
+
+    checkpoint = {
+        'meta': meta,
+        'state_dict': weights_to_cpu(get_state_dict(model))
+    }
+    # save optimizer state dict in the checkpoint
+    if isinstance(optimizer, Optimizer):
+        checkpoint['optimizer'] = optimizer.state_dict()
+    elif isinstance(optimizer, dict):
+        checkpoint['optimizer'] = {}
+        for name, optim in optimizer.items():
+            checkpoint['optimizer'][name] = optim.state_dict()
+
+    if filename.startswith('pavi://'):
+        try:
+            from pavi import modelcloud
+            from pavi.exception import NodeNotFoundError
+        except ImportError:
+            raise ImportError(
+                'Please install pavi to load checkpoint from modelcloud.')
+        model_path = filename[7:]
+        root = modelcloud.Folder()
+        model_dir, model_name = osp.split(model_path)
+        try:
+            model = modelcloud.get(model_dir)
+        except NodeNotFoundError:
+            model = root.create_training_model(model_dir)
+        with TemporaryDirectory() as tmp_dir:
+            checkpoint_file = osp.join(tmp_dir, model_name)
+            with open(checkpoint_file, 'wb') as f:
+                torch.save(checkpoint, f)
+                f.flush()
+            model.create_file(checkpoint_file, name=model_name)
+    else:
+        mmcv.mkdir_or_exist(osp.dirname(filename))
+        # immediately flush buffer
+        with open(filename, 'wb') as f:
+            torch.save(checkpoint, f)
+            f.flush()
\ No newline at end of file
diff --git a/annotator/uniformer/mmseg/apis/__init__.py b/annotator/uniformer/mmseg/apis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..170724be38de42daf2bc1a1910e181d68818f165
--- /dev/null
+++ b/annotator/uniformer/mmseg/apis/__init__.py
@@ -0,0 +1,9 @@
+from .inference import inference_segmentor, init_segmentor, show_result_pyplot
+from .test import multi_gpu_test, single_gpu_test
+from .train import get_root_logger, set_random_seed, train_segmentor
+
+__all__ = [
+    'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor',
+    'inference_segmentor', 'multi_gpu_test', 'single_gpu_test',
+    'show_result_pyplot'
+]
diff --git a/annotator/uniformer/mmseg/apis/inference.py b/annotator/uniformer/mmseg/apis/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..90bc1c0c68525734bd6793f07c15fe97d3c8342c
--- /dev/null
+++ b/annotator/uniformer/mmseg/apis/inference.py
@@ -0,0 +1,136 @@
+import matplotlib.pyplot as plt
+import annotator.uniformer.mmcv as mmcv
+import torch
+from annotator.uniformer.mmcv.parallel import collate, scatter
+from annotator.uniformer.mmcv.runner import load_checkpoint
+
+from annotator.uniformer.mmseg.datasets.pipelines import Compose
+from annotator.uniformer.mmseg.models import build_segmentor
+
+
+def init_segmentor(config, checkpoint=None, device='cuda:0'):
+    """Initialize a segmentor from config file.
+
+    Args:
+        config (str or :obj:`mmcv.Config`): Config file path or the config
+            object.
+        checkpoint (str, optional): Checkpoint path. If left as None, the model
+            will not load any weights.
+        device (str, optional): CPU/CUDA device option. Default: 'cuda:0'.
+            Use 'cpu' for loading model on CPU.
+    Returns:
+        nn.Module: The constructed segmentor.
+    """
+    if isinstance(config, str):
+        config = mmcv.Config.fromfile(config)
+    elif not isinstance(config, mmcv.Config):
+        raise TypeError('config must be a filename or Config object, '
+                        'but got {}'.format(type(config)))
+    config.model.pretrained = None
+    config.model.train_cfg = None
+    model = build_segmentor(config.model, test_cfg=config.get('test_cfg'))
+    if checkpoint is not None:
+        checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
+        model.CLASSES = checkpoint['meta']['CLASSES']
+        model.PALETTE = checkpoint['meta']['PALETTE']
+    model.cfg = config  # save the config in the model for convenience
+    model.to(device)
+    model.eval()
+    return model
+
+
+class LoadImage:
+    """A simple pipeline to load image."""
+
+    def __call__(self, results):
+        """Call function to load images into results.
+
+        Args:
+            results (dict): A result dict contains the file name
+                of the image to be read.
+
+        Returns:
+            dict: ``results`` will be returned containing loaded image.
+        """
+
+        if isinstance(results['img'], str):
+            results['filename'] = results['img']
+            results['ori_filename'] = results['img']
+        else:
+            results['filename'] = None
+            results['ori_filename'] = None
+        img = mmcv.imread(results['img'])
+        results['img'] = img
+        results['img_shape'] = img.shape
+        results['ori_shape'] = img.shape
+        return results
+
+
+def inference_segmentor(model, img):
+    """Inference image(s) with the segmentor.
+
+    Args:
+        model (nn.Module): The loaded segmentor.
+        img (str or ndarray): Either the image file path or a loaded image.
+
+    Returns:
+        (list[Tensor]): The segmentation result.
+    """
+    cfg = model.cfg
+    device = next(model.parameters()).device  # model device
+    # build the data pipeline
+    test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
+    test_pipeline = Compose(test_pipeline)
+    # prepare data
+    data = dict(img=img)
+    data = test_pipeline(data)
+    data = collate([data], samples_per_gpu=1)
+    if next(model.parameters()).is_cuda:
+        # scatter to specified GPU
+        data = scatter(data, [device])[0]
+    else:
+        data['img_metas'] = [i.data[0] for i in data['img_metas']]
+
+    # forward the model
+    with torch.no_grad():
+        result = model(return_loss=False, rescale=True, **data)
+    return result
+
+
+def show_result_pyplot(model,
+                       img,
+                       result,
+                       palette=None,
+                       fig_size=(15, 10),
+                       opacity=0.5,
+                       title='',
+                       block=True):
+    """Visualize the segmentation results on the image.
+
+    Args:
+        model (nn.Module): The loaded segmentor.
+        img (str or np.ndarray): Image filename or loaded image.
+        result (list): The segmentation result.
+        palette (list[list[int]]] | None): The palette of segmentation
+            map. If None is given, random palette will be generated.
+            Default: None
+        fig_size (tuple): Figure size of the pyplot figure.
+        opacity(float): Opacity of painted segmentation map.
+            Default 0.5.
+            Must be in (0, 1] range.
+        title (str): The title of pyplot figure.
+            Default is ''.
+        block (bool): Whether to block the pyplot figure.
+            Default is True.
+
+    Returns:
+        np.ndarray: The image with the segmentation drawn on it, in RGB order.
+    """
+    if hasattr(model, 'module'):
+        model = model.module
+    img = model.show_result(
+        img, result, palette=palette, show=False, opacity=opacity)
+    # plt.figure(figsize=fig_size)
+    # plt.imshow(mmcv.bgr2rgb(img))
+    # plt.title(title)
+    # plt.tight_layout()
+    # plt.show(block=block)
+    return mmcv.bgr2rgb(img)
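+
+
+# Illustrative end-to-end sketch (not part of the upstream file); the config,
+# checkpoint and image paths below are placeholders.
+#
+#     model = init_segmentor('path/to/config.py', 'path/to/checkpoint.pth',
+#                            device='cuda:0')
+#     result = inference_segmentor(model, 'demo.png')
+#     vis = show_result_pyplot(model, 'demo.png', result)  # RGB ndarray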
diff --git a/annotator/uniformer/mmseg/apis/test.py b/annotator/uniformer/mmseg/apis/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..e574eb7da04f09a59cf99ff953c36468ae87a326
--- /dev/null
+++ b/annotator/uniformer/mmseg/apis/test.py
@@ -0,0 +1,238 @@
+import os.path as osp
+import pickle
+import shutil
+import tempfile
+
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+import torch
+import torch.distributed as dist
+from annotator.uniformer.mmcv.image import tensor2imgs
+from annotator.uniformer.mmcv.runner import get_dist_info
+
+
+def np2tmp(array, temp_file_name=None):
+    """Save ndarray to local numpy file.
+
+    Args:
+        array (ndarray): Ndarray to save.
+        temp_file_name (str): Numpy file name. If 'temp_file_name=None', this
+            function will generate a file name with tempfile.NamedTemporaryFile
+            to save ndarray. Default: None.
+
+    Returns:
+        str: The numpy file name.
+    """
+
+    if temp_file_name is None:
+        temp_file_name = tempfile.NamedTemporaryFile(
+            suffix='.npy', delete=False).name
+    np.save(temp_file_name, array)
+    return temp_file_name
+
+
+def single_gpu_test(model,
+                    data_loader,
+                    show=False,
+                    out_dir=None,
+                    efficient_test=False,
+                    opacity=0.5):
+    """Test with single GPU.
+
+    Args:
+        model (nn.Module): Model to be tested.
+        data_loader (utils.data.Dataloader): Pytorch data loader.
+        show (bool): Whether to show results during inference. Default: False.
+        out_dir (str, optional): If specified, the painted results will be
+            dumped into this directory.
+        efficient_test (bool): Whether to save the results as local numpy
+            files to save CPU memory during evaluation. Default: False.
+        opacity(float): Opacity of painted segmentation map.
+            Default 0.5.
+            Must be in (0, 1] range.
+    Returns:
+        list: The prediction results.
+    """
+
+    model.eval()
+    results = []
+    dataset = data_loader.dataset
+    prog_bar = mmcv.ProgressBar(len(dataset))
+    for i, data in enumerate(data_loader):
+        with torch.no_grad():
+            result = model(return_loss=False, **data)
+
+        if show or out_dir:
+            img_tensor = data['img'][0]
+            img_metas = data['img_metas'][0].data[0]
+            imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
+            assert len(imgs) == len(img_metas)
+
+            for img, img_meta in zip(imgs, img_metas):
+                h, w, _ = img_meta['img_shape']
+                img_show = img[:h, :w, :]
+
+                ori_h, ori_w = img_meta['ori_shape'][:-1]
+                img_show = mmcv.imresize(img_show, (ori_w, ori_h))
+
+                if out_dir:
+                    out_file = osp.join(out_dir, img_meta['ori_filename'])
+                else:
+                    out_file = None
+
+                model.module.show_result(
+                    img_show,
+                    result,
+                    palette=dataset.PALETTE,
+                    show=show,
+                    out_file=out_file,
+                    opacity=opacity)
+
+        if isinstance(result, list):
+            if efficient_test:
+                result = [np2tmp(_) for _ in result]
+            results.extend(result)
+        else:
+            if efficient_test:
+                result = np2tmp(result)
+            results.append(result)
+
+        batch_size = len(result)
+        for _ in range(batch_size):
+            prog_bar.update()
+    return results
+
+
+def multi_gpu_test(model,
+                   data_loader,
+                   tmpdir=None,
+                   gpu_collect=False,
+                   efficient_test=False):
+    """Test model with multiple gpus.
+
+    This method tests the model with multiple gpus and collects the results
+    under two different modes: gpu and cpu. By setting 'gpu_collect=True',
+    it encodes results to gpu tensors and uses gpu communication for results
+    collection. In cpu mode it saves the results from different gpus to
+    'tmpdir' and the rank 0 worker collects them.
+
+    Args:
+        model (nn.Module): Model to be tested.
+        data_loader (utils.data.Dataloader): Pytorch data loader.
+        tmpdir (str): Path of directory to save the temporary results from
+            different gpus under cpu mode.
+        gpu_collect (bool): Option to use either gpu or cpu to collect results.
+        efficient_test (bool): Whether to save the results as local numpy
+            files to save CPU memory during evaluation. Default: False.
+
+    Returns:
+        list: The prediction results.
+    """
+
+    model.eval()
+    results = []
+    dataset = data_loader.dataset
+    rank, world_size = get_dist_info()
+    if rank == 0:
+        prog_bar = mmcv.ProgressBar(len(dataset))
+    for i, data in enumerate(data_loader):
+        with torch.no_grad():
+            result = model(return_loss=False, rescale=True, **data)
+
+        if isinstance(result, list):
+            if efficient_test:
+                result = [np2tmp(_) for _ in result]
+            results.extend(result)
+        else:
+            if efficient_test:
+                result = np2tmp(result)
+            results.append(result)
+
+        if rank == 0:
+            batch_size = data['img'][0].size(0)
+            for _ in range(batch_size * world_size):
+                prog_bar.update()
+
+    # collect results from all ranks
+    if gpu_collect:
+        results = collect_results_gpu(results, len(dataset))
+    else:
+        results = collect_results_cpu(results, len(dataset), tmpdir)
+    return results
+
+
+def collect_results_cpu(result_part, size, tmpdir=None):
+    """Collect results with CPU."""
+    rank, world_size = get_dist_info()
+    # create a tmp dir if it is not specified
+    if tmpdir is None:
+        MAX_LEN = 512
+        # 32 is whitespace
+        dir_tensor = torch.full((MAX_LEN, ),
+                                32,
+                                dtype=torch.uint8,
+                                device='cuda')
+        if rank == 0:
+            tmpdir = tempfile.mkdtemp()
+            tmpdir = torch.tensor(
+                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
+            dir_tensor[:len(tmpdir)] = tmpdir
+        dist.broadcast(dir_tensor, 0)
+        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+    else:
+        mmcv.mkdir_or_exist(tmpdir)
+    # dump the part result to the dir
+    mmcv.dump(result_part, osp.join(tmpdir, 'part_{}.pkl'.format(rank)))
+    dist.barrier()
+    # collect all parts
+    if rank != 0:
+        return None
+    else:
+        # load results of all parts from tmp dir
+        part_list = []
+        for i in range(world_size):
+            part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i))
+            part_list.append(mmcv.load(part_file))
+        # sort the results
+        ordered_results = []
+        for res in zip(*part_list):
+            ordered_results.extend(list(res))
+        # the dataloader may pad some samples
+        ordered_results = ordered_results[:size]
+        # remove tmp dir
+        shutil.rmtree(tmpdir)
+        return ordered_results
+
+
+def collect_results_gpu(result_part, size):
+    """Collect results with GPU."""
+    rank, world_size = get_dist_info()
+    # dump result part to tensor with pickle
+    part_tensor = torch.tensor(
+        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
+    # gather all result part tensor shape
+    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+    shape_list = [shape_tensor.clone() for _ in range(world_size)]
+    dist.all_gather(shape_list, shape_tensor)
+    # padding result part tensor to max length
+    shape_max = torch.tensor(shape_list).max()
+    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+    part_send[:shape_tensor[0]] = part_tensor
+    part_recv_list = [
+        part_tensor.new_zeros(shape_max) for _ in range(world_size)
+    ]
+    # gather all result part
+    dist.all_gather(part_recv_list, part_send)
+
+    if rank == 0:
+        part_list = []
+        for recv, shape in zip(part_recv_list, shape_list):
+            part_list.append(
+                pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
+        # sort the results
+        ordered_results = []
+        for res in zip(*part_list):
+            ordered_results.extend(list(res))
+        # the dataloader may pad some samples
+        ordered_results = ordered_results[:size]
+        return ordered_results
diff --git a/annotator/uniformer/mmseg/apis/train.py b/annotator/uniformer/mmseg/apis/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..63f319a919ff023931a6a663e668f27dd1a07a2e
--- /dev/null
+++ b/annotator/uniformer/mmseg/apis/train.py
@@ -0,0 +1,116 @@
+import random
+import warnings
+
+import numpy as np
+import torch
+from annotator.uniformer.mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+from annotator.uniformer.mmcv.runner import build_optimizer, build_runner
+
+from annotator.uniformer.mmseg.core import DistEvalHook, EvalHook
+from annotator.uniformer.mmseg.datasets import build_dataloader, build_dataset
+from annotator.uniformer.mmseg.utils import get_root_logger
+
+
+def set_random_seed(seed, deterministic=False):
+    """Set random seed.
+
+    Args:
+        seed (int): Seed to be used.
+        deterministic (bool): Whether to set the deterministic option for
+            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+            to True and `torch.backends.cudnn.benchmark` to False.
+            Default: False.
+    """
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    if deterministic:
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+
+def train_segmentor(model,
+                    dataset,
+                    cfg,
+                    distributed=False,
+                    validate=False,
+                    timestamp=None,
+                    meta=None):
+    """Launch segmentor training."""
+    logger = get_root_logger(cfg.log_level)
+
+    # prepare data loaders
+    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
+    data_loaders = [
+        build_dataloader(
+            ds,
+            cfg.data.samples_per_gpu,
+            cfg.data.workers_per_gpu,
+            # cfg.gpus will be ignored if distributed
+            len(cfg.gpu_ids),
+            dist=distributed,
+            seed=cfg.seed,
+            drop_last=True) for ds in dataset
+    ]
+
+    # put model on gpus
+    if distributed:
+        find_unused_parameters = cfg.get('find_unused_parameters', False)
+        # Sets the `find_unused_parameters` parameter in
+        # torch.nn.parallel.DistributedDataParallel
+        model = MMDistributedDataParallel(
+            model.cuda(),
+            device_ids=[torch.cuda.current_device()],
+            broadcast_buffers=False,
+            find_unused_parameters=find_unused_parameters)
+    else:
+        model = MMDataParallel(
+            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
+
+    # build runner
+    optimizer = build_optimizer(model, cfg.optimizer)
+
+    if cfg.get('runner') is None:
+        cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters}
+        warnings.warn(
+            'config is now expected to have a `runner` section, '
+            'please set `runner` in your config.', UserWarning)
+
+    runner = build_runner(
+        cfg.runner,
+        default_args=dict(
+            model=model,
+            batch_processor=None,
+            optimizer=optimizer,
+            work_dir=cfg.work_dir,
+            logger=logger,
+            meta=meta))
+
+    # register hooks
+    runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
+                                   cfg.checkpoint_config, cfg.log_config,
+                                   cfg.get('momentum_config', None))
+
+    # an ugly workaround to make the .log and .log.json filenames the same
+    runner.timestamp = timestamp
+
+    # register eval hooks
+    if validate:
+        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
+        val_dataloader = build_dataloader(
+            val_dataset,
+            samples_per_gpu=1,
+            workers_per_gpu=cfg.data.workers_per_gpu,
+            dist=distributed,
+            shuffle=False)
+        eval_cfg = cfg.get('evaluation', {})
+        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
+        eval_hook = DistEvalHook if distributed else EvalHook
+        runner.register_hook(eval_hook(val_dataloader, **eval_cfg), priority='LOW')
+
+    if cfg.resume_from:
+        runner.resume(cfg.resume_from)
+    elif cfg.load_from:
+        runner.load_checkpoint(cfg.load_from)
+    runner.run(data_loaders, cfg.workflow)
diff --git a/annotator/uniformer/mmseg/core/__init__.py b/annotator/uniformer/mmseg/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..965605587211b7bf0bd6bc3acdbb33dd49cab023
--- /dev/null
+++ b/annotator/uniformer/mmseg/core/__init__.py
@@ -0,0 +1,3 @@
+from .evaluation import *  # noqa: F401, F403
+from .seg import *  # noqa: F401, F403
+from .utils import *  # noqa: F401, F403
diff --git a/annotator/uniformer/mmseg/core/evaluation/__init__.py b/annotator/uniformer/mmseg/core/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7cc4b23413a0639e9de00eeb0bf600632d2c6cd
--- /dev/null
+++ b/annotator/uniformer/mmseg/core/evaluation/__init__.py
@@ -0,0 +1,8 @@
+from .class_names import get_classes, get_palette
+from .eval_hooks import DistEvalHook, EvalHook
+from .metrics import eval_metrics, mean_dice, mean_fscore, mean_iou
+
+__all__ = [
+    'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'mean_fscore',
+    'eval_metrics', 'get_classes', 'get_palette'
+]
diff --git a/annotator/uniformer/mmseg/core/evaluation/class_names.py b/annotator/uniformer/mmseg/core/evaluation/class_names.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffae816cf980ce4b03e491cc0c4298cb823797e6
--- /dev/null
+++ b/annotator/uniformer/mmseg/core/evaluation/class_names.py
@@ -0,0 +1,152 @@
+import annotator.uniformer.mmcv as mmcv
+
+
+def cityscapes_classes():
+    """Cityscapes class names for external use."""
+    return [
+        'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
+        'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
+        'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
+        'bicycle'
+    ]
+
+
+def ade_classes():
+    """ADE20K class names for external use."""
+    return [
+        'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
+        'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
+        'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
+        'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
+        'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
+        'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
+        'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
+        'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
+        'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
+        'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
+        'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
+        'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
+        'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
+        'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
+        'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
+        'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
+        'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
+        'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
+        'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
+        'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
+        'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
+        'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
+        'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
+        'clock', 'flag'
+    ]
+
+
+def voc_classes():
+    """Pascal VOC class names for external use."""
+    return [
+        'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
+        'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
+        'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
+        'tvmonitor'
+    ]
+
+
+def cityscapes_palette():
+    """Cityscapes palette for external use."""
+    return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
+            [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
+            [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
+            [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
+            [0, 0, 230], [119, 11, 32]]
+
+
+def ade_palette():
+    """ADE20K palette for external use."""
+    return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+            [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+            [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+            [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+            [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+            [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+            [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+            [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+            [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+            [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+            [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+            [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+            [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+            [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+            [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
+            [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
+            [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
+            [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
+            [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
+            [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
+            [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
+            [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
+            [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
+            [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
+            [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
+            [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
+            [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
+            [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
+            [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
+            [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
+            [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
+            [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
+            [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
+            [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
+            [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
+            [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
+            [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
+            [102, 255, 0], [92, 0, 255]]
+
+
+def voc_palette():
+    """Pascal VOC palette for external use."""
+    return [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
+            [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
+            [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
+            [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
+            [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
+
+
+dataset_aliases = {
+    'cityscapes': ['cityscapes'],
+    'ade': ['ade', 'ade20k'],
+    'voc': ['voc', 'pascal_voc', 'voc12', 'voc12aug']
+}
+
+
+def get_classes(dataset):
+    """Get class names of a dataset."""
+    alias2name = {}
+    for name, aliases in dataset_aliases.items():
+        for alias in aliases:
+            alias2name[alias] = name
+
+    if mmcv.is_str(dataset):
+        if dataset in alias2name:
+            labels = eval(alias2name[dataset] + '_classes()')
+        else:
+            raise ValueError(f'Unrecognized dataset: {dataset}')
+    else:
+        raise TypeError(f'dataset must be a str, but got {type(dataset)}')
+    return labels
+
+
+def get_palette(dataset):
+    """Get class palette (RGB) of a dataset."""
+    alias2name = {}
+    for name, aliases in dataset_aliases.items():
+        for alias in aliases:
+            alias2name[alias] = name
+
+    if mmcv.is_str(dataset):
+        if dataset in alias2name:
+            labels = eval(alias2name[dataset] + '_palette()')
+        else:
+            raise ValueError(f'Unrecognized dataset: {dataset}')
+    else:
+        raise TypeError(f'dataset must be a str, but got {type(dataset)}')
+    return labels
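+
+
+# Illustrative usage sketch (not part of the upstream file): dataset aliases
+# are resolved through `dataset_aliases` above.
+#
+#     classes = get_classes('ade')          # 150 ADE20K class names
+#     palette = get_palette('cityscapes')   # 19 RGB triplets
+#     assert len(classes) == 150 and len(palette) == 19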
diff --git a/annotator/uniformer/mmseg/core/evaluation/eval_hooks.py b/annotator/uniformer/mmseg/core/evaluation/eval_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fc100c8f96e817a6ed2666f7c9f762af2463b48
--- /dev/null
+++ b/annotator/uniformer/mmseg/core/evaluation/eval_hooks.py
@@ -0,0 +1,109 @@
+import os.path as osp
+
+from annotator.uniformer.mmcv.runner import DistEvalHook as _DistEvalHook
+from annotator.uniformer.mmcv.runner import EvalHook as _EvalHook
+
+
+class EvalHook(_EvalHook):
+    """Single GPU EvalHook, with efficient test support.
+
+    Args:
+        by_epoch (bool): Whether to perform evaluation by epoch (True) or by
+            iteration (False). Default: False.
+        efficient_test (bool): Whether to save the results as local numpy
+            files to save CPU memory during evaluation. Default: False.
+    """
+
+    greater_keys = ['mIoU', 'mAcc', 'aAcc']
+
+    def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
+        super().__init__(*args, by_epoch=by_epoch, **kwargs)
+        self.efficient_test = efficient_test
+
+    def after_train_iter(self, runner):
+        """After train epoch hook.
+
+        Override default ``single_gpu_test``.
+        """
+        if self.by_epoch or not self.every_n_iters(runner, self.interval):
+            return
+        from annotator.uniformer.mmseg.apis import single_gpu_test
+        runner.log_buffer.clear()
+        results = single_gpu_test(
+            runner.model,
+            self.dataloader,
+            show=False,
+            efficient_test=self.efficient_test)
+        self.evaluate(runner, results)
+
+    def after_train_epoch(self, runner):
+        """After train epoch hook.
+
+        Override default ``single_gpu_test``.
+        """
+        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
+            return
+        from annotator.uniformer.mmseg.apis import single_gpu_test
+        runner.log_buffer.clear()
+        results = single_gpu_test(runner.model, self.dataloader, show=False)
+        self.evaluate(runner, results)
+
+
+class DistEvalHook(_DistEvalHook):
+    """Distributed EvalHook, with efficient test support.
+
+    Args:
+        by_epoch (bool): Whether to perform evaluation by epoch (True) or by
+            iteration (False). Default: False.
+        efficient_test (bool): Whether to save the results as local numpy
+            files to save CPU memory during evaluation. Default: False.
+    """
+
+    greater_keys = ['mIoU', 'mAcc', 'aAcc']
+
+    def __init__(self, *args, by_epoch=False, efficient_test=False, **kwargs):
+        super().__init__(*args, by_epoch=by_epoch, **kwargs)
+        self.efficient_test = efficient_test
+
+    def after_train_iter(self, runner):
+        """After train epoch hook.
+
+        Override default ``multi_gpu_test``.
+        """
+        if self.by_epoch or not self.every_n_iters(runner, self.interval):
+            return
+        from annotator.uniformer.mmseg.apis import multi_gpu_test
+        runner.log_buffer.clear()
+        results = multi_gpu_test(
+            runner.model,
+            self.dataloader,
+            tmpdir=osp.join(runner.work_dir, '.eval_hook'),
+            gpu_collect=self.gpu_collect,
+            efficient_test=self.efficient_test)
+        if runner.rank == 0:
+            print('\n')
+            self.evaluate(runner, results)
+
+    def after_train_epoch(self, runner):
+        """After train epoch hook.
+
+        Override default ``multi_gpu_test``.
+        """
+        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
+            return
+        from annotator.uniformer.mmseg.apis import multi_gpu_test
+        runner.log_buffer.clear()
+        results = multi_gpu_test(
+            runner.model,
+            self.dataloader,
+            tmpdir=osp.join(runner.work_dir, '.eval_hook'),
+            gpu_collect=self.gpu_collect)
+        if runner.rank == 0:
+            print('\n')
+            self.evaluate(runner, results)
diff --git a/annotator/uniformer/mmseg/core/evaluation/metrics.py b/annotator/uniformer/mmseg/core/evaluation/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..16c7dd47cadd53cf1caaa194e28a343f2aacc599
--- /dev/null
+++ b/annotator/uniformer/mmseg/core/evaluation/metrics.py
@@ -0,0 +1,326 @@
+from collections import OrderedDict
+
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+import torch
+
+
+def f_score(precision, recall, beta=1):
+    """calcuate the f-score value.
+
+    Args:
+        precision (float | torch.Tensor): The precision value.
+        recall (float | torch.Tensor): The recall value.
+        beta (int): Determines the weight of recall in the combined score.
+            Default: False.
+
+    Returns:
+        [torch.tensor]: The f-score value.
+    """
+    score = (1 + beta**2) * (precision * recall) / (
+        (beta**2 * precision) + recall)
+    return score
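+
+
+# Worked example (illustrative, not part of the upstream file): with
+# precision = 0.5 and recall = 1.0, the balanced F-score (beta=1) is
+# 2 * 0.5 * 1.0 / (0.5 + 1.0) = 2/3.
+#
+#     assert abs(f_score(0.5, 1.0) - 2 / 3) < 1e-6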
+
+
+def intersect_and_union(pred_label,
+                        label,
+                        num_classes,
+                        ignore_index,
+                        label_map=dict(),
+                        reduce_zero_label=False):
+    """Calculate intersection and Union.
+
+    Args:
+        pred_label (ndarray | str): Prediction segmentation map
+            or predict result filename.
+        label (ndarray | str): Ground truth segmentation map
+            or label filename.
+        num_classes (int): Number of categories.
+        ignore_index (int): Index that will be ignored in evaluation.
+        label_map (dict): Mapping old labels to new labels. The parameter will
+            work only when label is str. Default: dict().
+        reduce_zero_label (bool): Whether to ignore the zero label. The
+            parameter will work only when label is str. Default: False.
+
+     Returns:
+         torch.Tensor: The intersection of prediction and ground truth
+            histogram on all classes.
+         torch.Tensor: The union of prediction and ground truth histogram on
+            all classes.
+         torch.Tensor: The prediction histogram on all classes.
+         torch.Tensor: The ground truth histogram on all classes.
+    """
+
+    if isinstance(pred_label, str):
+        pred_label = torch.from_numpy(np.load(pred_label))
+    else:
+        pred_label = torch.from_numpy(pred_label)
+
+    if isinstance(label, str):
+        label = torch.from_numpy(
+            mmcv.imread(label, flag='unchanged', backend='pillow'))
+    else:
+        label = torch.from_numpy(label)
+
+    if label_map is not None:
+        for old_id, new_id in label_map.items():
+            label[label == old_id] = new_id
+    if reduce_zero_label:
+        label[label == 0] = 255
+        label = label - 1
+        label[label == 254] = 255
+
+    mask = (label != ignore_index)
+    pred_label = pred_label[mask]
+    label = label[mask]
+
+    intersect = pred_label[pred_label == label]
+    area_intersect = torch.histc(
+        intersect.float(), bins=(num_classes), min=0, max=num_classes - 1)
+    area_pred_label = torch.histc(
+        pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1)
+    area_label = torch.histc(
+        label.float(), bins=(num_classes), min=0, max=num_classes - 1)
+    area_union = area_pred_label + area_label - area_intersect
+    return area_intersect, area_union, area_pred_label, area_label
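+
+
+# Illustrative usage sketch (not part of the upstream file): two classes on a
+# tiny 2x2 map, no ignored pixels.
+#
+#     import numpy as np
+#     pred = np.array([[0, 1], [1, 1]])
+#     gt = np.array([[0, 1], [0, 1]])
+#     inter, union, _, _ = intersect_and_union(pred, gt, num_classes=2,
+#                                              ignore_index=255)
+#     # inter -> tensor([1., 2.]), union -> tensor([2., 3.]),
+#     # so per-class IoU = [0.5, 0.667]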
+
+
+def total_intersect_and_union(results,
+                              gt_seg_maps,
+                              num_classes,
+                              ignore_index,
+                              label_map=dict(),
+                              reduce_zero_label=False):
+    """Calculate Total Intersection and Union.
+
+    Args:
+        results (list[ndarray] | list[str]): List of prediction segmentation
+            maps or list of prediction result filenames.
+        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+            segmentation maps or list of label filenames.
+        num_classes (int): Number of categories.
+        ignore_index (int): Index that will be ignored in evaluation.
+        label_map (dict): Mapping old labels to new labels. Default: dict().
+        reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
+
+     Returns:
+         ndarray: The intersection of prediction and ground truth histogram
+             on all classes.
+         ndarray: The union of prediction and ground truth histogram on all
+             classes.
+         ndarray: The prediction histogram on all classes.
+         ndarray: The ground truth histogram on all classes.
+    """
+    num_imgs = len(results)
+    assert len(gt_seg_maps) == num_imgs
+    total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
+    total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
+    total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
+    total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)
+    for i in range(num_imgs):
+        area_intersect, area_union, area_pred_label, area_label = \
+            intersect_and_union(
+                results[i], gt_seg_maps[i], num_classes, ignore_index,
+                label_map, reduce_zero_label)
+        total_area_intersect += area_intersect
+        total_area_union += area_union
+        total_area_pred_label += area_pred_label
+        total_area_label += area_label
+    return total_area_intersect, total_area_union, total_area_pred_label, \
+        total_area_label
+
+
+def mean_iou(results,
+             gt_seg_maps,
+             num_classes,
+             ignore_index,
+             nan_to_num=None,
+             label_map=dict(),
+             reduce_zero_label=False):
+    """Calculate Mean Intersection and Union (mIoU)
+
+    Args:
+        results (list[ndarray] | list[str]): List of prediction segmentation
+            maps or list of prediction result filenames.
+        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+            segmentation maps or list of label filenames.
+        num_classes (int): Number of categories.
+        ignore_index (int): Index that will be ignored in evaluation.
+        nan_to_num (int, optional): If specified, NaN values will be replaced
+            by the numbers defined by the user. Default: None.
+        label_map (dict): Mapping old labels to new labels. Default: dict().
+        reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
+
+     Returns:
+        dict[str, float | ndarray]:
+            <aAcc> float: Overall accuracy on all images.
+            <Acc> ndarray: Per category accuracy, shape (num_classes, ).
+            <IoU> ndarray: Per category IoU, shape (num_classes, ).
+    """
+    iou_result = eval_metrics(
+        results=results,
+        gt_seg_maps=gt_seg_maps,
+        num_classes=num_classes,
+        ignore_index=ignore_index,
+        metrics=['mIoU'],
+        nan_to_num=nan_to_num,
+        label_map=label_map,
+        reduce_zero_label=reduce_zero_label)
+    return iou_result
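+
+
+# Illustrative usage sketch (not part of the upstream file), reusing the tiny
+# 2x2 example from `intersect_and_union` above; the returned keys follow the
+# docstring (<aAcc>, <Acc>, <IoU>).
+#
+#     import numpy as np
+#     preds = [np.array([[0, 1], [1, 1]])]
+#     gts = [np.array([[0, 1], [0, 1]])]
+#     metrics = mean_iou(preds, gts, num_classes=2, ignore_index=255)
+#     # expected: aAcc = 0.75, per-class IoU ~ [0.5, 0.667]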
+
+
+def mean_dice(results,
+              gt_seg_maps,
+              num_classes,
+              ignore_index,
+              nan_to_num=None,
+              label_map=dict(),
+              reduce_zero_label=False):
+    """Calculate Mean Dice (mDice)
+
+    Args:
+        results (list[ndarray] | list[str]): List of prediction segmentation
+            maps or list of prediction result filenames.
+        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+            segmentation maps or list of label filenames.
+        num_classes (int): Number of categories.
+        ignore_index (int): Index that will be ignored in evaluation.
+        nan_to_num (int, optional): If specified, NaN values will be replaced
+            by the numbers defined by the user. Default: None.
+        label_map (dict): Mapping old labels to new labels. Default: dict().
+        reduce_zero_label (bool): Whether to ignore zero label. Default: False.
+
+    Returns:
+        dict[str, float | ndarray]: Default metrics.
+            <aAcc> float: Overall accuracy on all images.
+            <Acc> ndarray: Per category accuracy, shape (num_classes, ).
+            <Dice> ndarray: Per category dice, shape (num_classes, ).
+    """
+
+    dice_result = eval_metrics(
+        results=results,
+        gt_seg_maps=gt_seg_maps,
+        num_classes=num_classes,
+        ignore_index=ignore_index,
+        metrics=['mDice'],
+        nan_to_num=nan_to_num,
+        label_map=label_map,
+        reduce_zero_label=reduce_zero_label)
+    return dice_result
+
+
+def mean_fscore(results,
+                gt_seg_maps,
+                num_classes,
+                ignore_index,
+                nan_to_num=None,
+                label_map=dict(),
+                reduce_zero_label=False,
+                beta=1):
+    """Calculate Mean Intersection and Union (mIoU)
+
+    Args:
+        results (list[ndarray] | list[str]): List of prediction segmentation
+            maps or list of prediction result filenames.
+        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+            segmentation maps or list of label filenames.
+        num_classes (int): Number of categories.
+        ignore_index (int): Index that will be ignored in evaluation.
+        nan_to_num (int, optional): If specified, NaN values will be replaced
+            by the numbers defined by the user. Default: None.
+        label_map (dict): Mapping old labels to new labels. Default: dict().
+        reduce_zero_label (bool): Whether to ignore zero label. Default: False.
+        beta (int): Determines the weight of recall in the combined score.
+            Default: 1.
+
+    Returns:
+        dict[str, float | ndarray]: Default metrics.
+            <aAcc> float: Overall accuracy on all images.
+            <Fscore> ndarray: Per category f-score, shape (num_classes, ).
+            <Precision> ndarray: Per category precision, shape (num_classes, ).
+            <Recall> ndarray: Per category recall, shape (num_classes, ).
+    """
+    fscore_result = eval_metrics(
+        results=results,
+        gt_seg_maps=gt_seg_maps,
+        num_classes=num_classes,
+        ignore_index=ignore_index,
+        metrics=['mFscore'],
+        nan_to_num=nan_to_num,
+        label_map=label_map,
+        reduce_zero_label=reduce_zero_label,
+        beta=beta)
+    return fscore_result
+
+
+def eval_metrics(results,
+                 gt_seg_maps,
+                 num_classes,
+                 ignore_index,
+                 metrics=['mIoU'],
+                 nan_to_num=None,
+                 label_map=dict(),
+                 reduce_zero_label=False,
+                 beta=1):
+    """Calculate evaluation metrics
+    Args:
+        results (list[ndarray] | list[str]): List of prediction segmentation
+            maps or list of prediction result filenames.
+        gt_seg_maps (list[ndarray] | list[str]): list of ground truth
+            segmentation maps or list of label filenames.
+        num_classes (int): Number of categories.
+        ignore_index (int): Index that will be ignored in evaluation.
+        metrics (list[str] | str): Metrics to be evaluated. Options are
+            'mIoU', 'mDice' and 'mFscore'.
+        nan_to_num (int, optional): If specified, NaN values will be replaced
+            by the numbers defined by the user. Default: None.
+        label_map (dict): Mapping old labels to new labels. Default: dict().
+        reduce_zero_label (bool): Whether to ignore zero label. Default: False.
+
+    Returns:
+        float: Overall accuracy on all images.
+        ndarray: Per category accuracy, shape (num_classes, ).
+        ndarray: Per category evaluation metrics, shape (num_classes, ).
+    """
+    if isinstance(metrics, str):
+        metrics = [metrics]
+    allowed_metrics = ['mIoU', 'mDice', 'mFscore']
+    if not set(metrics).issubset(set(allowed_metrics)):
+        raise KeyError('metrics {} is not supported'.format(metrics))
+
+    total_area_intersect, total_area_union, total_area_pred_label, \
+        total_area_label = total_intersect_and_union(
+            results, gt_seg_maps, num_classes, ignore_index, label_map,
+            reduce_zero_label)
+    all_acc = total_area_intersect.sum() / total_area_label.sum()
+    ret_metrics = OrderedDict({'aAcc': all_acc})
+    for metric in metrics:
+        if metric == 'mIoU':
+            iou = total_area_intersect / total_area_union
+            acc = total_area_intersect / total_area_label
+            ret_metrics['IoU'] = iou
+            ret_metrics['Acc'] = acc
+        elif metric == 'mDice':
+            dice = 2 * total_area_intersect / (
+                total_area_pred_label + total_area_label)
+            acc = total_area_intersect / total_area_label
+            ret_metrics['Dice'] = dice
+            ret_metrics['Acc'] = acc
+        elif metric == 'mFscore':
+            precision = total_area_intersect / total_area_pred_label
+            recall = total_area_intersect / total_area_label
+            f_value = torch.tensor(
+                [f_score(x[0], x[1], beta) for x in zip(precision, recall)])
+            ret_metrics['Fscore'] = f_value
+            ret_metrics['Precision'] = precision
+            ret_metrics['Recall'] = recall
+
+    ret_metrics = {
+        metric: value.numpy()
+        for metric, value in ret_metrics.items()
+    }
+    if nan_to_num is not None:
+        ret_metrics = OrderedDict({
+            metric: np.nan_to_num(metric_value, nan=nan_to_num)
+            for metric, metric_value in ret_metrics.items()
+        })
+    return ret_metrics
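+
+
+if __name__ == '__main__':
+    # Hedged smoke test, not part of the original module: one made-up pair of
+    # 4x4 maps with 3 classes, just to show how mean_iou() wraps eval_metrics()
+    # (assumes the repository root is on PYTHONPATH).
+    import numpy as np
+
+    pred = [np.array([[0, 0, 1, 1],
+                      [0, 0, 1, 1],
+                      [2, 2, 2, 2],
+                      [2, 2, 2, 2]])]
+    gt = [np.array([[0, 0, 1, 1],
+                    [0, 1, 1, 1],
+                    [2, 2, 2, 2],
+                    [2, 2, 1, 2]])]
+    metrics = mean_iou(pred, gt, num_classes=3, ignore_index=255)
+    print(metrics['aAcc'])  # overall pixel accuracy
+    print(metrics['IoU'])   # per-class IoU, shape (3,)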
diff --git a/annotator/uniformer/mmseg/core/seg/__init__.py b/annotator/uniformer/mmseg/core/seg/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..93bc129b685e4a3efca2cc891729981b2865900d
--- /dev/null
+++ b/annotator/uniformer/mmseg/core/seg/__init__.py
@@ -0,0 +1,4 @@
+from .builder import build_pixel_sampler
+from .sampler import BasePixelSampler, OHEMPixelSampler
+
+__all__ = ['build_pixel_sampler', 'BasePixelSampler', 'OHEMPixelSampler']
diff --git a/annotator/uniformer/mmseg/core/seg/builder.py b/annotator/uniformer/mmseg/core/seg/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..db61f03d4abb2072f2532ce4429c0842495e015b
--- /dev/null
+++ b/annotator/uniformer/mmseg/core/seg/builder.py
@@ -0,0 +1,8 @@
+from annotator.uniformer.mmcv.utils import Registry, build_from_cfg
+
+PIXEL_SAMPLERS = Registry('pixel sampler')
+
+
+def build_pixel_sampler(cfg, **default_args):
+    """Build pixel sampler for segmentation map."""
+    return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args)
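+
+
+if __name__ == '__main__':
+    # Hedged sketch of building a sampler from a config dict. `_DummyHead` is
+    # an assumption standing in for a decode head exposing `ignore_index`.
+    # Importing via the package ensures OHEMPixelSampler is registered in the
+    # same PIXEL_SAMPLERS registry that build_pixel_sampler() looks up.
+    from annotator.uniformer.mmseg.core.seg import build_pixel_sampler
+
+    class _DummyHead:
+        ignore_index = 255
+
+    cfg = dict(type='OHEMPixelSampler', thresh=0.7, min_kept=100000)
+    sampler = build_pixel_sampler(cfg, context=_DummyHead())
+    print(type(sampler).__name__)  # -> OHEMPixelSampler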
diff --git a/annotator/uniformer/mmseg/core/seg/sampler/__init__.py b/annotator/uniformer/mmseg/core/seg/sampler/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..332b242c03d1c5e80d4577df442a9a037b1816e1
--- /dev/null
+++ b/annotator/uniformer/mmseg/core/seg/sampler/__init__.py
@@ -0,0 +1,4 @@
+from .base_pixel_sampler import BasePixelSampler
+from .ohem_pixel_sampler import OHEMPixelSampler
+
+__all__ = ['BasePixelSampler', 'OHEMPixelSampler']
diff --git a/annotator/uniformer/mmseg/core/seg/sampler/base_pixel_sampler.py b/annotator/uniformer/mmseg/core/seg/sampler/base_pixel_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b75b1566c9f18169cee51d4b55d75e0357b69c57
--- /dev/null
+++ b/annotator/uniformer/mmseg/core/seg/sampler/base_pixel_sampler.py
@@ -0,0 +1,12 @@
+from abc import ABCMeta, abstractmethod
+
+
+class BasePixelSampler(metaclass=ABCMeta):
+    """Base class of pixel sampler."""
+
+    def __init__(self, **kwargs):
+        pass
+
+    @abstractmethod
+    def sample(self, seg_logit, seg_label):
+        """Placeholder for sample function."""
diff --git a/annotator/uniformer/mmseg/core/seg/sampler/ohem_pixel_sampler.py b/annotator/uniformer/mmseg/core/seg/sampler/ohem_pixel_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..88bb10d44026ba9f21756eaea9e550841cd59b9f
--- /dev/null
+++ b/annotator/uniformer/mmseg/core/seg/sampler/ohem_pixel_sampler.py
@@ -0,0 +1,76 @@
+import torch
+import torch.nn.functional as F
+
+from ..builder import PIXEL_SAMPLERS
+from .base_pixel_sampler import BasePixelSampler
+
+
+@PIXEL_SAMPLERS.register_module()
+class OHEMPixelSampler(BasePixelSampler):
+    """Online Hard Example Mining Sampler for segmentation.
+
+    Args:
+        context (nn.Module): The context of sampler, subclass of
+            :obj:`BaseDecodeHead`.
+        thresh (float, optional): The threshold for hard example selection.
+            Predictions with confidence below this threshold are treated as
+            hard examples. If not specified, the hard examples are the pixels
+            with the top ``min_kept`` losses. Default: None.
+        min_kept (int, optional): The minimum number of predictions to keep.
+            Default: 100000.
+    """
+
+    def __init__(self, context, thresh=None, min_kept=100000):
+        super(OHEMPixelSampler, self).__init__()
+        self.context = context
+        assert min_kept > 1
+        self.thresh = thresh
+        self.min_kept = min_kept
+
+    def sample(self, seg_logit, seg_label):
+        """Sample pixels that have high loss or with low prediction confidence.
+
+        Args:
+            seg_logit (torch.Tensor): segmentation logits, shape (N, C, H, W)
+            seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W)
+
+        Returns:
+            torch.Tensor: segmentation weight, shape (N, H, W)
+        """
+        with torch.no_grad():
+            assert seg_logit.shape[2:] == seg_label.shape[2:]
+            assert seg_label.shape[1] == 1
+            seg_label = seg_label.squeeze(1).long()
+            batch_kept = self.min_kept * seg_label.size(0)
+            valid_mask = seg_label != self.context.ignore_index
+            seg_weight = seg_logit.new_zeros(size=seg_label.size())
+            valid_seg_weight = seg_weight[valid_mask]
+            if self.thresh is not None:
+                seg_prob = F.softmax(seg_logit, dim=1)
+
+                tmp_seg_label = seg_label.clone().unsqueeze(1)
+                tmp_seg_label[tmp_seg_label == self.context.ignore_index] = 0
+                seg_prob = seg_prob.gather(1, tmp_seg_label).squeeze(1)
+                sort_prob, sort_indices = seg_prob[valid_mask].sort()
+
+                if sort_prob.numel() > 0:
+                    min_threshold = sort_prob[min(batch_kept,
+                                                  sort_prob.numel() - 1)]
+                else:
+                    min_threshold = 0.0
+                threshold = max(min_threshold, self.thresh)
+                valid_seg_weight[seg_prob[valid_mask] < threshold] = 1.
+            else:
+                losses = self.context.loss_decode(
+                    seg_logit,
+                    seg_label,
+                    weight=None,
+                    ignore_index=self.context.ignore_index,
+                    reduction_override='none')
+                # faster than topk according to https://github.com/pytorch/pytorch/issues/22812  # noqa
+                _, sort_indices = losses[valid_mask].sort(descending=True)
+                valid_seg_weight[sort_indices[:batch_kept]] = 1.
+
+            seg_weight[valid_mask] = valid_seg_weight
+
+            return seg_weight
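+
+
+if __name__ == '__main__':
+    # Hedged sketch with random tensors and a dummy context (run as a module,
+    # e.g. `python -m annotator.uniformer.mmseg.core.seg.sampler.ohem_pixel_sampler`,
+    # so the relative imports resolve). The threshold branch marks pixels whose
+    # predicted probability for their label falls below the (possibly
+    # tightened) threshold as hard examples.
+    class _DummyContext:
+        ignore_index = 255
+
+    sampler = OHEMPixelSampler(_DummyContext(), thresh=0.7, min_kept=4)
+    seg_logit = torch.randn(1, 3, 8, 8)            # (N, C, H, W)
+    seg_label = torch.randint(0, 3, (1, 1, 8, 8))  # (N, 1, H, W)
+    seg_weight = sampler.sample(seg_logit, seg_label)
+    print(seg_weight.shape, seg_weight.sum())      # (1, 8, 8), #hard pixels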
diff --git a/annotator/uniformer/mmseg/core/utils/__init__.py b/annotator/uniformer/mmseg/core/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2678b321c295bcceaef945111ac3524be19d6e4
--- /dev/null
+++ b/annotator/uniformer/mmseg/core/utils/__init__.py
@@ -0,0 +1,3 @@
+from .misc import add_prefix
+
+__all__ = ['add_prefix']
diff --git a/annotator/uniformer/mmseg/core/utils/misc.py b/annotator/uniformer/mmseg/core/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb862a82bd47c8624db3dd5c6fb6ad8a03b62466
--- /dev/null
+++ b/annotator/uniformer/mmseg/core/utils/misc.py
@@ -0,0 +1,17 @@
+def add_prefix(inputs, prefix):
+    """Add prefix for dict.
+
+    Args:
+        inputs (dict): The input dict with str keys.
+        prefix (str): The prefix to add.
+
+    Returns:
+        dict: The dict with keys updated with ``prefix``.
+    """
+
+    outputs = dict()
+    for name, value in inputs.items():
+        outputs[f'{prefix}.{name}'] = value
+
+    return outputs
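+
+
+if __name__ == '__main__':
+    # Tiny self-check with made-up values: prefixing a loss dict, e.g. to keep
+    # decode-head and auxiliary-head losses apart in the same log dict.
+    print(add_prefix({'loss_ce': 0.3, 'acc_seg': 0.9}, 'aux'))
+    # -> {'aux.loss_ce': 0.3, 'aux.acc_seg': 0.9}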
diff --git a/annotator/uniformer/mmseg/datasets/__init__.py b/annotator/uniformer/mmseg/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebeaef4a28ef655e43578552a8aef6b77f13a636
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/__init__.py
@@ -0,0 +1,19 @@
+from .ade import ADE20KDataset
+from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
+from .chase_db1 import ChaseDB1Dataset
+from .cityscapes import CityscapesDataset
+from .custom import CustomDataset
+from .dataset_wrappers import ConcatDataset, RepeatDataset
+from .drive import DRIVEDataset
+from .hrf import HRFDataset
+from .pascal_context import PascalContextDataset, PascalContextDataset59
+from .stare import STAREDataset
+from .voc import PascalVOCDataset
+
+__all__ = [
+    'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
+    'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset',
+    'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
+    'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
+    'STAREDataset'
+]
diff --git a/annotator/uniformer/mmseg/datasets/ade.py b/annotator/uniformer/mmseg/datasets/ade.py
new file mode 100644
index 0000000000000000000000000000000000000000..5913e43775ed4920b6934c855eb5a37c54218ebf
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/ade.py
@@ -0,0 +1,84 @@
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class ADE20KDataset(CustomDataset):
+    """ADE20K dataset.
+
+    In segmentation map annotation for ADE20K, 0 stands for background, which
+    is not included in 150 categories. ``reduce_zero_label`` is fixed to True.
+    The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to
+    '.png'.
+    """
+    CLASSES = (
+        'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
+        'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
+        'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
+        'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
+        'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
+        'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
+        'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
+        'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
+        'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
+        'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
+        'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
+        'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
+        'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
+        'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
+        'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
+        'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
+        'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
+        'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
+        'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
+        'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
+        'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
+        'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
+        'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
+        'clock', 'flag')
+
+    PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+               [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+               [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+               [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+               [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+               [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+               [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+               [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+               [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+               [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+               [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+               [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+               [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+               [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+               [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
+               [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
+               [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
+               [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
+               [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
+               [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
+               [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
+               [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
+               [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
+               [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
+               [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
+               [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
+               [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
+               [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
+               [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
+               [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
+               [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
+               [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
+               [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
+               [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
+               [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
+               [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
+               [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
+               [102, 255, 0], [92, 0, 255]]
+
+    def __init__(self, **kwargs):
+        super(ADE20KDataset, self).__init__(
+            img_suffix='.jpg',
+            seg_map_suffix='.png',
+            reduce_zero_label=True,
+            **kwargs)
diff --git a/annotator/uniformer/mmseg/datasets/builder.py b/annotator/uniformer/mmseg/datasets/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..0798b14cd8b39fc58d8f2a4930f1e079b5bf8b55
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/builder.py
@@ -0,0 +1,169 @@
+import copy
+import platform
+import random
+from functools import partial
+
+import numpy as np
+from annotator.uniformer.mmcv.parallel import collate
+from annotator.uniformer.mmcv.runner import get_dist_info
+from annotator.uniformer.mmcv.utils import Registry, build_from_cfg
+from annotator.uniformer.mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader
+from torch.utils.data import DistributedSampler
+
+if platform.system() != 'Windows':
+    # https://github.com/pytorch/pytorch/issues/973
+    import resource
+    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+    hard_limit = rlimit[1]
+    soft_limit = min(4096, hard_limit)
+    resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
+
+DATASETS = Registry('dataset')
+PIPELINES = Registry('pipeline')
+
+
+def _concat_dataset(cfg, default_args=None):
+    """Build :obj:`ConcatDataset by."""
+    from .dataset_wrappers import ConcatDataset
+    img_dir = cfg['img_dir']
+    ann_dir = cfg.get('ann_dir', None)
+    split = cfg.get('split', None)
+    num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1
+    if ann_dir is not None:
+        num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1
+    else:
+        num_ann_dir = 0
+    if split is not None:
+        num_split = len(split) if isinstance(split, (list, tuple)) else 1
+    else:
+        num_split = 0
+    if num_img_dir > 1:
+        assert num_img_dir == num_ann_dir or num_ann_dir == 0
+        assert num_img_dir == num_split or num_split == 0
+    else:
+        assert num_split == num_ann_dir or num_ann_dir <= 1
+    num_dset = max(num_split, num_img_dir)
+
+    datasets = []
+    for i in range(num_dset):
+        data_cfg = copy.deepcopy(cfg)
+        if isinstance(img_dir, (list, tuple)):
+            data_cfg['img_dir'] = img_dir[i]
+        if isinstance(ann_dir, (list, tuple)):
+            data_cfg['ann_dir'] = ann_dir[i]
+        if isinstance(split, (list, tuple)):
+            data_cfg['split'] = split[i]
+        datasets.append(build_dataset(data_cfg, default_args))
+
+    return ConcatDataset(datasets)
+
+
+def build_dataset(cfg, default_args=None):
+    """Build datasets."""
+    from .dataset_wrappers import ConcatDataset, RepeatDataset
+    if isinstance(cfg, (list, tuple)):
+        dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
+    elif cfg['type'] == 'RepeatDataset':
+        dataset = RepeatDataset(
+            build_dataset(cfg['dataset'], default_args), cfg['times'])
+    elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance(
+            cfg.get('split', None), (list, tuple)):
+        dataset = _concat_dataset(cfg, default_args)
+    else:
+        dataset = build_from_cfg(cfg, DATASETS, default_args)
+
+    return dataset
+
+
+def build_dataloader(dataset,
+                     samples_per_gpu,
+                     workers_per_gpu,
+                     num_gpus=1,
+                     dist=True,
+                     shuffle=True,
+                     seed=None,
+                     drop_last=False,
+                     pin_memory=True,
+                     dataloader_type='PoolDataLoader',
+                     **kwargs):
+    """Build PyTorch DataLoader.
+
+    In distributed training, each GPU/process has a dataloader.
+    In non-distributed training, there is only one dataloader for all GPUs.
+
+    Args:
+        dataset (Dataset): A PyTorch dataset.
+        samples_per_gpu (int): Number of training samples on each GPU, i.e.,
+            batch size of each GPU.
+        workers_per_gpu (int): How many subprocesses to use for data loading
+            for each GPU.
+        num_gpus (int): Number of GPUs. Only used in non-distributed training.
+        dist (bool): Distributed training/test or not. Default: True.
+        shuffle (bool): Whether to shuffle the data at every epoch.
+            Default: True.
+        seed (int | None): Seed to be used. Default: None.
+        drop_last (bool): Whether to drop the last incomplete batch in epoch.
+            Default: False
+        pin_memory (bool): Whether to use pin_memory in DataLoader.
+            Default: True
+        dataloader_type (str): Type of dataloader. Default: 'PoolDataLoader'
+        kwargs: any keyword argument to be used to initialize DataLoader
+
+    Returns:
+        DataLoader: A PyTorch dataloader.
+    """
+    rank, world_size = get_dist_info()
+    if dist:
+        sampler = DistributedSampler(
+            dataset, world_size, rank, shuffle=shuffle)
+        shuffle = False
+        batch_size = samples_per_gpu
+        num_workers = workers_per_gpu
+    else:
+        sampler = None
+        batch_size = num_gpus * samples_per_gpu
+        num_workers = num_gpus * workers_per_gpu
+
+    init_fn = partial(
+        worker_init_fn, num_workers=num_workers, rank=rank,
+        seed=seed) if seed is not None else None
+
+    assert dataloader_type in (
+        'DataLoader',
+        'PoolDataLoader'), f'unsupported dataloader {dataloader_type}'
+
+    if dataloader_type == 'PoolDataLoader':
+        dataloader = PoolDataLoader
+    elif dataloader_type == 'DataLoader':
+        dataloader = DataLoader
+
+    data_loader = dataloader(
+        dataset,
+        batch_size=batch_size,
+        sampler=sampler,
+        num_workers=num_workers,
+        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
+        pin_memory=pin_memory,
+        shuffle=shuffle,
+        worker_init_fn=init_fn,
+        drop_last=drop_last,
+        **kwargs)
+
+    return data_loader
+
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+    """Worker init func for dataloader.
+
+    The seed of each worker equals num_workers * rank + worker_id + user_seed.
+
+    Args:
+        worker_id (int): Worker id.
+        num_workers (int): Number of workers.
+        rank (int): The rank of current process.
+        seed (int): The random seed to use.
+    """
+
+    worker_seed = num_workers * rank + worker_id + seed
+    np.random.seed(worker_seed)
+    random.seed(worker_seed)
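+
+
+if __name__ == '__main__':
+    # Hedged sketch: the ADE20K paths below are placeholders and the pipeline
+    # is left empty only for brevity; a real config lists transforms registered
+    # in PIPELINES. The package-level imports are used so the dataset classes
+    # are registered in the same DATASETS registry that build_dataset() reads.
+    from annotator.uniformer.mmseg.datasets import (build_dataloader,
+                                                    build_dataset)
+
+    dataset_cfg = dict(
+        type='ADE20KDataset',
+        img_dir='data/ade/ADEChallengeData2016/images/validation',
+        ann_dir='data/ade/ADEChallengeData2016/annotations/validation',
+        pipeline=[])
+    dataset = build_dataset(dataset_cfg)
+    data_loader = build_dataloader(
+        dataset,
+        samples_per_gpu=2,
+        workers_per_gpu=2,
+        dist=False,
+        shuffle=False,
+        dataloader_type='DataLoader')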
diff --git a/annotator/uniformer/mmseg/datasets/chase_db1.py b/annotator/uniformer/mmseg/datasets/chase_db1.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bc29bea14704a4407f83474610cbc3bef32c708
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/chase_db1.py
@@ -0,0 +1,27 @@
+import os.path as osp
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class ChaseDB1Dataset(CustomDataset):
+    """Chase_db1 dataset.
+
+    In segmentation map annotation for Chase_db1, 0 stands for background,
+    which is included in 2 categories. ``reduce_zero_label`` is fixed to False.
+    The ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
+    '_1stHO.png'.
+    """
+
+    CLASSES = ('background', 'vessel')
+
+    PALETTE = [[120, 120, 120], [6, 230, 230]]
+
+    def __init__(self, **kwargs):
+        super(ChaseDB1Dataset, self).__init__(
+            img_suffix='.png',
+            seg_map_suffix='_1stHO.png',
+            reduce_zero_label=False,
+            **kwargs)
+        assert osp.exists(self.img_dir)
diff --git a/annotator/uniformer/mmseg/datasets/cityscapes.py b/annotator/uniformer/mmseg/datasets/cityscapes.py
new file mode 100644
index 0000000000000000000000000000000000000000..81e47a914a1aa2e5458e18669d65ffb742f46fc6
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/cityscapes.py
@@ -0,0 +1,217 @@
+import os.path as osp
+import tempfile
+
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+from annotator.uniformer.mmcv.utils import print_log
+from PIL import Image
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class CityscapesDataset(CustomDataset):
+    """Cityscapes dataset.
+
+    The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is
+    fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset.
+    """
+
+    CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
+               'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
+               'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
+               'bicycle')
+
+    PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
+               [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
+               [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
+               [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
+               [0, 80, 100], [0, 0, 230], [119, 11, 32]]
+
+    def __init__(self, **kwargs):
+        super(CityscapesDataset, self).__init__(
+            img_suffix='_leftImg8bit.png',
+            seg_map_suffix='_gtFine_labelTrainIds.png',
+            **kwargs)
+
+    @staticmethod
+    def _convert_to_label_id(result):
+        """Convert trainId to id for cityscapes."""
+        if isinstance(result, str):
+            result = np.load(result)
+        import cityscapesscripts.helpers.labels as CSLabels
+        result_copy = result.copy()
+        for trainId, label in CSLabels.trainId2label.items():
+            result_copy[result == trainId] = label.id
+
+        return result_copy
+
+    def results2img(self, results, imgfile_prefix, to_label_id):
+        """Write the segmentation results to images.
+
+        Args:
+            results (list[list | tuple | ndarray]): Testing results of the
+                dataset.
+            imgfile_prefix (str): The filename prefix of the png files.
+                If the prefix is "somepath/xxx",
+                the png files will be named "somepath/xxx.png".
+            to_label_id (bool): Whether to convert the output to label ids
+                for submission.
+
+        Returns:
+            list[str]: Paths of the saved png files that contain the
+            corresponding semantic segmentation results.
+        """
+        mmcv.mkdir_or_exist(imgfile_prefix)
+        result_files = []
+        prog_bar = mmcv.ProgressBar(len(self))
+        for idx in range(len(self)):
+            result = results[idx]
+            if to_label_id:
+                result = self._convert_to_label_id(result)
+            filename = self.img_infos[idx]['filename']
+            basename = osp.splitext(osp.basename(filename))[0]
+
+            png_filename = osp.join(imgfile_prefix, f'{basename}.png')
+
+            output = Image.fromarray(result.astype(np.uint8)).convert('P')
+            import cityscapesscripts.helpers.labels as CSLabels
+            palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8)
+            for label_id, label in CSLabels.id2label.items():
+                palette[label_id] = label.color
+
+            output.putpalette(palette)
+            output.save(png_filename)
+            result_files.append(png_filename)
+            prog_bar.update()
+
+        return result_files
+
+    def format_results(self, results, imgfile_prefix=None, to_label_id=True):
+        """Format the results into dir (standard format for Cityscapes
+        evaluation).
+
+        Args:
+            results (list): Testing results of the dataset.
+            imgfile_prefix (str | None): The prefix of images files. It
+                includes the file path and the prefix of filename, e.g.,
+                "a/b/prefix". If not specified, a temp file will be created.
+                Default: None.
+            to_label_id (bool): Whether to convert the output to label ids
+                for submission. Default: True.
+
+        Returns:
+            tuple: (result_files, tmp_dir), result_files is a list containing
+                the image paths, tmp_dir is the temporary directory created
+                for saving json/png files when imgfile_prefix is not specified.
+        """
+
+        assert isinstance(results, list), 'results must be a list'
+        assert len(results) == len(self), (
+            'The length of results is not equal to the dataset len: '
+            f'{len(results)} != {len(self)}')
+
+        if imgfile_prefix is None:
+            tmp_dir = tempfile.TemporaryDirectory()
+            imgfile_prefix = tmp_dir.name
+        else:
+            tmp_dir = None
+        result_files = self.results2img(results, imgfile_prefix, to_label_id)
+
+        return result_files, tmp_dir
+
+    def evaluate(self,
+                 results,
+                 metric='mIoU',
+                 logger=None,
+                 imgfile_prefix=None,
+                 efficient_test=False):
+        """Evaluation in Cityscapes/default protocol.
+
+        Args:
+            results (list): Testing results of the dataset.
+            metric (str | list[str]): Metrics to be evaluated.
+            logger (logging.Logger | None | str): Logger used for printing
+                related information during evaluation. Default: None.
+            imgfile_prefix (str | None): The prefix of output image file,
+                for cityscapes evaluation only. It includes the file path and
+                the prefix of filename, e.g., "a/b/prefix".
+                If results are evaluated with cityscapes protocol, it would be
+                the prefix of output png files. The output files would be
+                png images under folder "a/b/prefix/xxx.png", where "xxx" is
+                the image name of cityscapes. If not specified, a temp file
+                will be created for evaluation.
+                Default: None.
+
+        Returns:
+            dict[str, float]: Cityscapes/default metrics.
+        """
+
+        eval_results = dict()
+        metrics = metric.copy() if isinstance(metric, list) else [metric]
+        if 'cityscapes' in metrics:
+            eval_results.update(
+                self._evaluate_cityscapes(results, logger, imgfile_prefix))
+            metrics.remove('cityscapes')
+        if len(metrics) > 0:
+            eval_results.update(
+                super(CityscapesDataset,
+                      self).evaluate(results, metrics, logger, efficient_test))
+
+        return eval_results
+
+    def _evaluate_cityscapes(self, results, logger, imgfile_prefix):
+        """Evaluation in Cityscapes protocol.
+
+        Args:
+            results (list): Testing results of the dataset.
+            logger (logging.Logger | str | None): Logger used for printing
+                related information during evaluation. Default: None.
+            imgfile_prefix (str | None): The prefix of output image file
+
+        Returns:
+            dict[str: float]: Cityscapes evaluation results.
+        """
+        try:
+            import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval  # noqa
+        except ImportError:
+            raise ImportError('Please run "pip install cityscapesscripts" to '
+                              'install cityscapesscripts first.')
+        msg = 'Evaluating in Cityscapes style'
+        if logger is None:
+            msg = '\n' + msg
+        print_log(msg, logger=logger)
+
+        result_files, tmp_dir = self.format_results(results, imgfile_prefix)
+
+        if tmp_dir is None:
+            result_dir = imgfile_prefix
+        else:
+            result_dir = tmp_dir.name
+
+        eval_results = dict()
+        print_log(f'Evaluating results under {result_dir} ...', logger=logger)
+
+        CSEval.args.evalInstLevelScore = True
+        CSEval.args.predictionPath = osp.abspath(result_dir)
+        CSEval.args.evalPixelAccuracy = True
+        CSEval.args.JSONOutput = False
+
+        seg_map_list = []
+        pred_list = []
+
+        # when evaluating with official cityscapesscripts,
+        # **_gtFine_labelIds.png is used
+        for seg_map in mmcv.scandir(
+                self.ann_dir, 'gtFine_labelIds.png', recursive=True):
+            seg_map_list.append(osp.join(self.ann_dir, seg_map))
+            pred_list.append(CSEval.getPrediction(CSEval.args, seg_map))
+
+        eval_results.update(
+            CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args))
+
+        if tmp_dir is not None:
+            tmp_dir.cleanup()
+
+        return eval_results
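+
+
+if __name__ == '__main__':
+    # Hedged sketch (needs the optional `cityscapesscripts` package; run as a
+    # module, e.g. `python -m annotator.uniformer.mmseg.datasets.cityscapes`,
+    # so the relative imports resolve). A tiny made-up trainId map is mapped
+    # to the official label ids used for submission.
+    import numpy as np
+    train_ids = np.array([[0, 1],
+                          [11, 13]], dtype=np.uint8)  # road, sidewalk, person, car
+    label_ids = CityscapesDataset._convert_to_label_id(train_ids)
+    print(label_ids)  # road->7, sidewalk->8, person->24, car->26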
diff --git a/annotator/uniformer/mmseg/datasets/custom.py b/annotator/uniformer/mmseg/datasets/custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8eb2a709cc7a3a68fc6a1e3a1ad98faef4c5b7b
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/custom.py
@@ -0,0 +1,400 @@
+import os
+import os.path as osp
+from collections import OrderedDict
+from functools import reduce
+
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+from annotator.uniformer.mmcv.utils import print_log
+from prettytable import PrettyTable
+from torch.utils.data import Dataset
+
+from annotator.uniformer.mmseg.core import eval_metrics
+from annotator.uniformer.mmseg.utils import get_root_logger
+from .builder import DATASETS
+from .pipelines import Compose
+
+
+@DATASETS.register_module()
+class CustomDataset(Dataset):
+    """Custom dataset for semantic segmentation. An example of file structure
+    is as followed.
+
+    .. code-block:: none
+
+        ├── data
+        │   ├── my_dataset
+        │   │   ├── img_dir
+        │   │   │   ├── train
+        │   │   │   │   ├── xxx{img_suffix}
+        │   │   │   │   ├── yyy{img_suffix}
+        │   │   │   │   ├── zzz{img_suffix}
+        │   │   │   ├── val
+        │   │   ├── ann_dir
+        │   │   │   ├── train
+        │   │   │   │   ├── xxx{seg_map_suffix}
+        │   │   │   │   ├── yyy{seg_map_suffix}
+        │   │   │   │   ├── zzz{seg_map_suffix}
+        │   │   │   ├── val
+
+    The img/gt_semantic_seg pair of CustomDataset should share the same name
+    except for the suffix. A valid img/gt_semantic_seg filename pair should be
+    like ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also
+    included in the suffix). If split is given, then ``xxx`` is specified in
+    the txt file. Otherwise, all files in ``img_dir/`` and ``ann_dir`` will be
+    loaded. Please refer to ``docs/tutorials/new_dataset.md`` for more details.
+
+
+    Args:
+        pipeline (list[dict]): Processing pipeline
+        img_dir (str): Path to image directory
+        img_suffix (str): Suffix of images. Default: '.jpg'
+        ann_dir (str, optional): Path to annotation directory. Default: None
+        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
+        split (str, optional): Split txt file. If split is specified, only
+            file with suffix in the splits will be loaded. Otherwise, all
+            images in img_dir/ann_dir will be loaded. Default: None
+        data_root (str, optional): Data root for img_dir/ann_dir. Default:
+            None.
+        test_mode (bool): If test_mode=True, gt wouldn't be loaded.
+        ignore_index (int): The label index to be ignored. Default: 255
+        reduce_zero_label (bool): Whether to mark label zero as ignored.
+            Default: False
+        classes (str | Sequence[str], optional): Specify classes to load.
+            If is None, ``cls.CLASSES`` will be used. Default: None.
+        palette (Sequence[Sequence[int]] | np.ndarray | None):
+            The palette of segmentation map. If None is given, and
+            self.PALETTE is None, random palette will be generated.
+            Default: None
+    """
+
+    CLASSES = None
+
+    PALETTE = None
+
+    def __init__(self,
+                 pipeline,
+                 img_dir,
+                 img_suffix='.jpg',
+                 ann_dir=None,
+                 seg_map_suffix='.png',
+                 split=None,
+                 data_root=None,
+                 test_mode=False,
+                 ignore_index=255,
+                 reduce_zero_label=False,
+                 classes=None,
+                 palette=None):
+        self.pipeline = Compose(pipeline)
+        self.img_dir = img_dir
+        self.img_suffix = img_suffix
+        self.ann_dir = ann_dir
+        self.seg_map_suffix = seg_map_suffix
+        self.split = split
+        self.data_root = data_root
+        self.test_mode = test_mode
+        self.ignore_index = ignore_index
+        self.reduce_zero_label = reduce_zero_label
+        self.label_map = None
+        self.CLASSES, self.PALETTE = self.get_classes_and_palette(
+            classes, palette)
+
+        # join paths if data_root is specified
+        if self.data_root is not None:
+            if not osp.isabs(self.img_dir):
+                self.img_dir = osp.join(self.data_root, self.img_dir)
+            if not (self.ann_dir is None or osp.isabs(self.ann_dir)):
+                self.ann_dir = osp.join(self.data_root, self.ann_dir)
+            if not (self.split is None or osp.isabs(self.split)):
+                self.split = osp.join(self.data_root, self.split)
+
+        # load annotations
+        self.img_infos = self.load_annotations(self.img_dir, self.img_suffix,
+                                               self.ann_dir,
+                                               self.seg_map_suffix, self.split)
+
+    def __len__(self):
+        """Total number of samples of data."""
+        return len(self.img_infos)
+
+    def load_annotations(self, img_dir, img_suffix, ann_dir, seg_map_suffix,
+                         split):
+        """Load annotation from directory.
+
+        Args:
+            img_dir (str): Path to image directory
+            img_suffix (str): Suffix of images.
+            ann_dir (str|None): Path to annotation directory.
+            seg_map_suffix (str|None): Suffix of segmentation maps.
+            split (str|None): Split txt file. If split is specified, only file
+                with suffix in the splits will be loaded. Otherwise, all images
+                in img_dir/ann_dir will be loaded. Default: None
+
+        Returns:
+            list[dict]: All image info of dataset.
+        """
+
+        img_infos = []
+        if split is not None:
+            with open(split) as f:
+                for line in f:
+                    img_name = line.strip()
+                    img_info = dict(filename=img_name + img_suffix)
+                    if ann_dir is not None:
+                        seg_map = img_name + seg_map_suffix
+                        img_info['ann'] = dict(seg_map=seg_map)
+                    img_infos.append(img_info)
+        else:
+            for img in mmcv.scandir(img_dir, img_suffix, recursive=True):
+                img_info = dict(filename=img)
+                if ann_dir is not None:
+                    seg_map = img.replace(img_suffix, seg_map_suffix)
+                    img_info['ann'] = dict(seg_map=seg_map)
+                img_infos.append(img_info)
+
+        print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger())
+        return img_infos
+
+    def get_ann_info(self, idx):
+        """Get annotation by index.
+
+        Args:
+            idx (int): Index of data.
+
+        Returns:
+            dict: Annotation info of specified index.
+        """
+
+        return self.img_infos[idx]['ann']
+
+    def pre_pipeline(self, results):
+        """Prepare results dict for pipeline."""
+        results['seg_fields'] = []
+        results['img_prefix'] = self.img_dir
+        results['seg_prefix'] = self.ann_dir
+        if self.custom_classes:
+            results['label_map'] = self.label_map
+
+    def __getitem__(self, idx):
+        """Get training/test data after pipeline.
+
+        Args:
+            idx (int): Index of data.
+
+        Returns:
+            dict: Training/test data (with annotation if `test_mode` is set
+                False).
+        """
+
+        if self.test_mode:
+            return self.prepare_test_img(idx)
+        else:
+            return self.prepare_train_img(idx)
+
+    def prepare_train_img(self, idx):
+        """Get training data and annotations after pipeline.
+
+        Args:
+            idx (int): Index of data.
+
+        Returns:
+            dict: Training data and annotation after pipeline with new keys
+                introduced by pipeline.
+        """
+
+        img_info = self.img_infos[idx]
+        ann_info = self.get_ann_info(idx)
+        results = dict(img_info=img_info, ann_info=ann_info)
+        self.pre_pipeline(results)
+        return self.pipeline(results)
+
+    def prepare_test_img(self, idx):
+        """Get testing data after pipeline.
+
+        Args:
+            idx (int): Index of data.
+
+        Returns:
+            dict: Testing data after pipeline with new keys introduced by
+                pipeline.
+        """
+
+        img_info = self.img_infos[idx]
+        results = dict(img_info=img_info)
+        self.pre_pipeline(results)
+        return self.pipeline(results)
+
+    def format_results(self, results, **kwargs):
+        """Place holder to format result to dataset specific output."""
+
+    def get_gt_seg_maps(self, efficient_test=False):
+        """Get ground truth segmentation maps for evaluation."""
+        gt_seg_maps = []
+        for img_info in self.img_infos:
+            seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map'])
+            if efficient_test:
+                gt_seg_map = seg_map
+            else:
+                gt_seg_map = mmcv.imread(
+                    seg_map, flag='unchanged', backend='pillow')
+            gt_seg_maps.append(gt_seg_map)
+        return gt_seg_maps
+
+    def get_classes_and_palette(self, classes=None, palette=None):
+        """Get class names of current dataset.
+
+        Args:
+            classes (Sequence[str] | str | None): If classes is None, use
+                default CLASSES defined by builtin dataset. If classes is a
+                string, take it as a file name. The file contains the name of
+                classes where each line contains one class name. If classes is
+                a tuple or list, override the CLASSES defined by the dataset.
+            palette (Sequence[Sequence[int]] | np.ndarray | None):
+                The palette of segmentation map. If None is given, random
+                palette will be generated. Default: None
+        """
+        if classes is None:
+            self.custom_classes = False
+            return self.CLASSES, self.PALETTE
+
+        self.custom_classes = True
+        if isinstance(classes, str):
+            # take it as a file path
+            class_names = mmcv.list_from_file(classes)
+        elif isinstance(classes, (tuple, list)):
+            class_names = classes
+        else:
+            raise ValueError(f'Unsupported type {type(classes)} of classes.')
+
+        if self.CLASSES:
+            if not set(classes).issubset(self.CLASSES):
+                raise ValueError('classes is not a subset of CLASSES.')
+
+            # dictionary, its keys are the old label ids and its values
+            # are the new label ids.
+            # used for changing pixel labels in load_annotations.
+            self.label_map = {}
+            for i, c in enumerate(self.CLASSES):
+                if c not in class_names:
+                    self.label_map[i] = -1
+                else:
+                    self.label_map[i] = classes.index(c)
+
+        palette = self.get_palette_for_custom_classes(class_names, palette)
+
+        return class_names, palette
+
+    def get_palette_for_custom_classes(self, class_names, palette=None):
+
+        if self.label_map is not None:
+            # return subset of palette
+            palette = []
+            for old_id, new_id in sorted(
+                    self.label_map.items(), key=lambda x: x[1]):
+                if new_id != -1:
+                    palette.append(self.PALETTE[old_id])
+            palette = type(self.PALETTE)(palette)
+
+        elif palette is None:
+            if self.PALETTE is None:
+                palette = np.random.randint(0, 255, size=(len(class_names), 3))
+            else:
+                palette = self.PALETTE
+
+        return palette
+
+    def evaluate(self,
+                 results,
+                 metric='mIoU',
+                 logger=None,
+                 efficient_test=False,
+                 **kwargs):
+        """Evaluate the dataset.
+
+        Args:
+            results (list): Testing results of the dataset.
+            metric (str | list[str]): Metrics to be evaluated. 'mIoU',
+                'mDice' and 'mFscore' are supported.
+            logger (logging.Logger | None | str): Logger used for printing
+                related information during evaluation. Default: None.
+
+        Returns:
+            dict[str, float]: Default metrics.
+        """
+
+        if isinstance(metric, str):
+            metric = [metric]
+        allowed_metrics = ['mIoU', 'mDice', 'mFscore']
+        if not set(metric).issubset(set(allowed_metrics)):
+            raise KeyError('metric {} is not supported'.format(metric))
+        eval_results = {}
+        gt_seg_maps = self.get_gt_seg_maps(efficient_test)
+        if self.CLASSES is None:
+            num_classes = len(
+                reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
+        else:
+            num_classes = len(self.CLASSES)
+        ret_metrics = eval_metrics(
+            results,
+            gt_seg_maps,
+            num_classes,
+            self.ignore_index,
+            metric,
+            label_map=self.label_map,
+            reduce_zero_label=self.reduce_zero_label)
+
+        if self.CLASSES is None:
+            class_names = tuple(range(num_classes))
+        else:
+            class_names = self.CLASSES
+
+        # summary table
+        ret_metrics_summary = OrderedDict({
+            ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
+            for ret_metric, ret_metric_value in ret_metrics.items()
+        })
+
+        # each class table
+        ret_metrics.pop('aAcc', None)
+        ret_metrics_class = OrderedDict({
+            ret_metric: np.round(ret_metric_value * 100, 2)
+            for ret_metric, ret_metric_value in ret_metrics.items()
+        })
+        ret_metrics_class.update({'Class': class_names})
+        ret_metrics_class.move_to_end('Class', last=False)
+
+        # for logger
+        class_table_data = PrettyTable()
+        for key, val in ret_metrics_class.items():
+            class_table_data.add_column(key, val)
+
+        summary_table_data = PrettyTable()
+        for key, val in ret_metrics_summary.items():
+            if key == 'aAcc':
+                summary_table_data.add_column(key, [val])
+            else:
+                summary_table_data.add_column('m' + key, [val])
+
+        print_log('per class results:', logger)
+        print_log('\n' + class_table_data.get_string(), logger=logger)
+        print_log('Summary:', logger)
+        print_log('\n' + summary_table_data.get_string(), logger=logger)
+
+        # each metric dict
+        for key, value in ret_metrics_summary.items():
+            if key == 'aAcc':
+                eval_results[key] = value / 100.0
+            else:
+                eval_results['m' + key] = value / 100.0
+
+        ret_metrics_class.pop('Class', None)
+        for key, value in ret_metrics_class.items():
+            eval_results.update({
+                key + '.' + str(name): value[idx] / 100.0
+                for idx, name in enumerate(class_names)
+            })
+
+        if mmcv.is_list_of(results, str):
+            for file_name in results:
+                os.remove(file_name)
+        return eval_results
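+
+
+if __name__ == '__main__':
+    # Hedged sketch (run as a module so the relative imports resolve). The
+    # directory layout is a placeholder and the pipeline is empty only for
+    # brevity; point img_dir/ann_dir at a real layout before running.
+    dataset = CustomDataset(
+        pipeline=[],
+        img_dir='data/my_dataset/img_dir/val',
+        ann_dir='data/my_dataset/ann_dir/val',
+        img_suffix='.jpg',
+        seg_map_suffix='.png')
+    # `results` would come from a segmentor's test loop: one ndarray of class
+    # indices per image, in the same order as the dataset.
+    # metrics = dataset.evaluate(results, metric=['mIoU', 'mFscore'])
+    # print(metrics['aAcc'], metrics['mIoU'])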
diff --git a/annotator/uniformer/mmseg/datasets/dataset_wrappers.py b/annotator/uniformer/mmseg/datasets/dataset_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6a5e957ec3b44465432617cf6e8f0b86a8a5efa
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/dataset_wrappers.py
@@ -0,0 +1,50 @@
+from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
+
+from .builder import DATASETS
+
+
+@DATASETS.register_module()
+class ConcatDataset(_ConcatDataset):
+    """A wrapper of concatenated dataset.
+
+    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
+    concat the group flag for image aspect ratio.
+
+    Args:
+        datasets (list[:obj:`Dataset`]): A list of datasets.
+    """
+
+    def __init__(self, datasets):
+        super(ConcatDataset, self).__init__(datasets)
+        self.CLASSES = datasets[0].CLASSES
+        self.PALETTE = datasets[0].PALETTE
+
+
+@DATASETS.register_module()
+class RepeatDataset(object):
+    """A wrapper of repeated dataset.
+
+    The length of the repeated dataset will be ``times`` times that of the
+    original dataset. This is useful when the data loading time is long but
+    the dataset is small. Using RepeatDataset can reduce the data loading
+    time between epochs.
+
+    Args:
+        dataset (:obj:`Dataset`): The dataset to be repeated.
+        times (int): Repeat times.
+    """
+
+    def __init__(self, dataset, times):
+        self.dataset = dataset
+        self.times = times
+        self.CLASSES = dataset.CLASSES
+        self.PALETTE = dataset.PALETTE
+        self._ori_len = len(self.dataset)
+
+    def __getitem__(self, idx):
+        """Get item from original dataset."""
+        return self.dataset[idx % self._ori_len]
+
+    def __len__(self):
+        """The length is multiplied by ``times``"""
+        return self.times * self._ori_len
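+
+
+if __name__ == '__main__':
+    # Minimal sketch with a stand-in dataset (not an mmseg class; run as a
+    # module so the relative import resolves) showing RepeatDataset's index
+    # wrap-around.
+    class _Toy:
+        CLASSES = ('a', 'b')
+        PALETTE = [[0, 0, 0], [255, 255, 255]]
+
+        def __len__(self):
+            return 3
+
+        def __getitem__(self, idx):
+            return idx
+
+    repeated = RepeatDataset(_Toy(), times=4)
+    print(len(repeated))  # 12
+    print(repeated[7])    # 7 % 3 == 1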
diff --git a/annotator/uniformer/mmseg/datasets/drive.py b/annotator/uniformer/mmseg/datasets/drive.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cbfda8ae74bdf26c5aef197ff2866a7c7ad0cfd
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/drive.py
@@ -0,0 +1,27 @@
+import os.path as osp
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class DRIVEDataset(CustomDataset):
+    """DRIVE dataset.
+
+    In segmentation map annotation for DRIVE, 0 stands for background, which is
+    included in 2 categories. ``reduce_zero_label`` is fixed to False. The
+    ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
+    '_manual1.png'.
+    """
+
+    CLASSES = ('background', 'vessel')
+
+    PALETTE = [[120, 120, 120], [6, 230, 230]]
+
+    def __init__(self, **kwargs):
+        super(DRIVEDataset, self).__init__(
+            img_suffix='.png',
+            seg_map_suffix='_manual1.png',
+            reduce_zero_label=False,
+            **kwargs)
+        assert osp.exists(self.img_dir)
diff --git a/annotator/uniformer/mmseg/datasets/hrf.py b/annotator/uniformer/mmseg/datasets/hrf.py
new file mode 100644
index 0000000000000000000000000000000000000000..923203b51377f9344277fc561803d7a78bd2c684
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/hrf.py
@@ -0,0 +1,27 @@
+import os.path as osp
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class HRFDataset(CustomDataset):
+    """HRF dataset.
+
+    In segmentation map annotation for HRF, 0 stands for background, which is
+    included in 2 categories. ``reduce_zero_label`` is fixed to False. The
+    ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
+    '.png'.
+    """
+
+    CLASSES = ('background', 'vessel')
+
+    PALETTE = [[120, 120, 120], [6, 230, 230]]
+
+    def __init__(self, **kwargs):
+        super(HRFDataset, self).__init__(
+            img_suffix='.png',
+            seg_map_suffix='.png',
+            reduce_zero_label=False,
+            **kwargs)
+        assert osp.exists(self.img_dir)
diff --git a/annotator/uniformer/mmseg/datasets/pascal_context.py b/annotator/uniformer/mmseg/datasets/pascal_context.py
new file mode 100644
index 0000000000000000000000000000000000000000..541a63c66a13fb16fd52921e755715ad8d078fdd
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/pascal_context.py
@@ -0,0 +1,103 @@
+import os.path as osp
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class PascalContextDataset(CustomDataset):
+    """PascalContext dataset.
+
+    In segmentation map annotation for PascalContext, 0 stands for background,
+    which is included in 60 categories. ``reduce_zero_label`` is fixed to
+    False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
+    fixed to '.png'.
+
+    Args:
+        split (str): Split txt file for PascalContext.
+    """
+
+    CLASSES = ('background', 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench',
+               'bicycle', 'bird', 'boat', 'book', 'bottle', 'building', 'bus',
+               'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth',
+               'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence',
+               'floor', 'flower', 'food', 'grass', 'ground', 'horse',
+               'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person',
+               'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep',
+               'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table',
+               'track', 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water',
+               'window', 'wood')
+
+    PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
+               [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
+               [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
+               [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
+               [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
+               [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
+               [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
+               [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
+               [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
+               [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
+               [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
+               [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
+               [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
+               [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
+               [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]
+
+    def __init__(self, split, **kwargs):
+        super(PascalContextDataset, self).__init__(
+            img_suffix='.jpg',
+            seg_map_suffix='.png',
+            split=split,
+            reduce_zero_label=False,
+            **kwargs)
+        assert osp.exists(self.img_dir) and self.split is not None
+
+
+@DATASETS.register_module()
+class PascalContextDataset59(CustomDataset):
+    """PascalContext59 dataset.
+
+    In segmentation map annotation for PascalContext59, 0 stands for
+    background, which is not included in the 59 categories.
+    ``reduce_zero_label`` is fixed to True. The ``img_suffix`` is fixed to
+    '.jpg' and ``seg_map_suffix`` is fixed to '.png'.
+
+    Args:
+        split (str): Split txt file for PascalContext.
+    """
+
+    CLASSES = ('aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle',
+               'bird', 'boat', 'book', 'bottle', 'building', 'bus', 'cabinet',
+               'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow',
+               'cup', 'curtain', 'dog', 'door', 'fence', 'floor', 'flower',
+               'food', 'grass', 'ground', 'horse', 'keyboard', 'light',
+               'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform',
+               'pottedplant', 'road', 'rock', 'sheep', 'shelves', 'sidewalk',
+               'sign', 'sky', 'snow', 'sofa', 'table', 'track', 'train',
+               'tree', 'truck', 'tvmonitor', 'wall', 'water', 'window', 'wood')
+
+    PALETTE = [[180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3],
+               [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230],
+               [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61],
+               [120, 120, 70], [8, 255, 51], [255, 6, 82], [143, 255, 140],
+               [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200],
+               [61, 230, 250], [255, 6, 51], [11, 102, 255], [255, 7, 71],
+               [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92],
+               [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6],
+               [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8],
+               [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8],
+               [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255],
+               [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140],
+               [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0],
+               [255, 224, 0], [153, 255, 0], [0, 0, 255], [255, 71, 0],
+               [0, 235, 255], [0, 173, 255], [31, 0, 255]]
+
+    def __init__(self, split, **kwargs):
+        super(PascalContextDataset59, self).__init__(
+            img_suffix='.jpg',
+            seg_map_suffix='.png',
+            split=split,
+            reduce_zero_label=True,
+            **kwargs)
+        assert osp.exists(self.img_dir) and self.split is not None
diff --git a/annotator/uniformer/mmseg/datasets/pipelines/__init__.py b/annotator/uniformer/mmseg/datasets/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b9046b07bb4ddea7a707a392b42e72db7c9df67
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/pipelines/__init__.py
@@ -0,0 +1,16 @@
+from .compose import Compose
+from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor,
+                        Transpose, to_tensor)
+from .loading import LoadAnnotations, LoadImageFromFile
+from .test_time_aug import MultiScaleFlipAug
+from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
+                         PhotoMetricDistortion, RandomCrop, RandomFlip,
+                         RandomRotate, Rerange, Resize, RGB2Gray, SegRescale)
+
+__all__ = [
+    'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
+    'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
+    'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
+    'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
+    'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray'
+]
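+
+
+# Hedged sketch of a typical training pipeline assembled from the transforms
+# exported above (the crop size and normalization statistics are placeholders,
+# not values prescribed by this repo):
+def _example_train_pipeline(crop_size=(512, 512)):
+    return Compose([
+        dict(type='LoadImageFromFile'),
+        dict(type='LoadAnnotations'),
+        dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
+        dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+        dict(type='RandomFlip', prob=0.5),
+        dict(type='PhotoMetricDistortion'),
+        dict(type='Normalize', mean=[123.675, 116.28, 103.53],
+             std=[58.395, 57.12, 57.375], to_rgb=True),
+        dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+        dict(type='DefaultFormatBundle'),
+        dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+    ])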
diff --git a/annotator/uniformer/mmseg/datasets/pipelines/compose.py b/annotator/uniformer/mmseg/datasets/pipelines/compose.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbfcbb925c6d4ebf849328b9f94ef6fc24359bf5
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/pipelines/compose.py
@@ -0,0 +1,51 @@
+import collections
+
+from annotator.uniformer.mmcv.utils import build_from_cfg
+
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class Compose(object):
+    """Compose multiple transforms sequentially.
+
+    Args:
+        transforms (Sequence[dict | callable]): Sequence of transform object or
+            config dict to be composed.
+    """
+
+    def __init__(self, transforms):
+        assert isinstance(transforms, collections.abc.Sequence)
+        self.transforms = []
+        for transform in transforms:
+            if isinstance(transform, dict):
+                transform = build_from_cfg(transform, PIPELINES)
+                self.transforms.append(transform)
+            elif callable(transform):
+                self.transforms.append(transform)
+            else:
+                raise TypeError('transform must be callable or a dict')
+
+    def __call__(self, data):
+        """Call function to apply transforms sequentially.
+
+        Args:
+            data (dict): A result dict contains the data to transform.
+
+        Returns:
+           dict: Transformed data.
+        """
+
+        for t in self.transforms:
+            data = t(data)
+            if data is None:
+                return None
+        return data
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + '('
+        for t in self.transforms:
+            format_string += '\n'
+            format_string += f'    {t}'
+        format_string += '\n)'
+        return format_string
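+
+
+# Hedged sketch (unused): Compose accepts both registry config dicts and
+# plain callables.  Building the dict entry assumes 'LoadImageFromFile' has
+# already been registered by importing the pipelines package.
+def _example_compose():
+    return Compose([
+        dict(type='LoadImageFromFile'),
+        lambda results: results,  # no-op callable transform
+    ])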
diff --git a/annotator/uniformer/mmseg/datasets/pipelines/formating.py b/annotator/uniformer/mmseg/datasets/pipelines/formating.py
new file mode 100644
index 0000000000000000000000000000000000000000..97db85f4f9db39fb86ba77ead7d1a8407d810adb
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/pipelines/formating.py
@@ -0,0 +1,288 @@
+from collections.abc import Sequence
+
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+import torch
+from annotator.uniformer.mmcv.parallel import DataContainer as DC
+
+from ..builder import PIPELINES
+
+
+def to_tensor(data):
+    """Convert objects of various python types to :obj:`torch.Tensor`.
+
+    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
+    :class:`Sequence`, :class:`int` and :class:`float`.
+
+    Args:
+        data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
+            be converted.
+    """
+
+    if isinstance(data, torch.Tensor):
+        return data
+    elif isinstance(data, np.ndarray):
+        return torch.from_numpy(data)
+    elif isinstance(data, Sequence) and not mmcv.is_str(data):
+        return torch.tensor(data)
+    elif isinstance(data, int):
+        return torch.LongTensor([data])
+    elif isinstance(data, float):
+        return torch.FloatTensor([data])
+    else:
+        raise TypeError(f'type {type(data)} cannot be converted to tensor.')
+
+
+@PIPELINES.register_module()
+class ToTensor(object):
+    """Convert some results to :obj:`torch.Tensor` by given keys.
+
+    Args:
+        keys (Sequence[str]): Keys that need to be converted to Tensor.
+    """
+
+    def __init__(self, keys):
+        self.keys = keys
+
+    def __call__(self, results):
+        """Call function to convert data in results to :obj:`torch.Tensor`.
+
+        Args:
+            results (dict): Result dict contains the data to convert.
+
+        Returns:
+            dict: The result dict contains the data converted
+                to :obj:`torch.Tensor`.
+        """
+
+        for key in self.keys:
+            results[key] = to_tensor(results[key])
+        return results
+
+    def __repr__(self):
+        return self.__class__.__name__ + f'(keys={self.keys})'
+
+
+@PIPELINES.register_module()
+class ImageToTensor(object):
+    """Convert image to :obj:`torch.Tensor` by given keys.
+
+    The dimension order of input image is (H, W, C). The pipeline will convert
+    it to (C, H, W). If only 2 dimensions (H, W) are given, the output will be
+    (1, H, W).
+
+    Args:
+        keys (Sequence[str]): Key of images to be converted to Tensor.
+    """
+
+    def __init__(self, keys):
+        self.keys = keys
+
+    def __call__(self, results):
+        """Call function to convert image in results to :obj:`torch.Tensor` and
+        transpose the channel order.
+
+        Args:
+            results (dict): Result dict contains the image data to convert.
+
+        Returns:
+            dict: The result dict contains the image converted
+                to :obj:`torch.Tensor` and transposed to (C, H, W) order.
+        """
+
+        for key in self.keys:
+            img = results[key]
+            if len(img.shape) < 3:
+                img = np.expand_dims(img, -1)
+            results[key] = to_tensor(img.transpose(2, 0, 1))
+        return results
+
+    def __repr__(self):
+        return self.__class__.__name__ + f'(keys={self.keys})'
+
+
+@PIPELINES.register_module()
+class Transpose(object):
+    """Transpose some results by given keys.
+
+    Args:
+        keys (Sequence[str]): Keys of results to be transposed.
+        order (Sequence[int]): Order of transpose.
+    """
+
+    def __init__(self, keys, order):
+        self.keys = keys
+        self.order = order
+
+    def __call__(self, results):
+        """Call function to transpose the data in results by the given keys
+        and order.
+
+        Args:
+            results (dict): Result dict contains the data to transpose.
+
+        Returns:
+            dict: The result dict contains the data transposed according to
+                ``self.order``.
+        """
+
+        for key in self.keys:
+            results[key] = results[key].transpose(self.order)
+        return results
+
+    def __repr__(self):
+        return self.__class__.__name__ + \
+               f'(keys={self.keys}, order={self.order})'
+
+
+@PIPELINES.register_module()
+class ToDataContainer(object):
+    """Convert results to :obj:`mmcv.DataContainer` by given fields.
+
+    Args:
+        fields (Sequence[dict]): Each field is a dict like
+            ``dict(key='xxx', **kwargs)``. The ``key`` in result will
+            be converted to :obj:`mmcv.DataContainer` with ``**kwargs``.
+            Default: ``(dict(key='img', stack=True),
+            dict(key='gt_semantic_seg'))``.
+    """
+
+    def __init__(self,
+                 fields=(dict(key='img',
+                              stack=True), dict(key='gt_semantic_seg'))):
+        self.fields = fields
+
+    def __call__(self, results):
+        """Call function to convert data in results to
+        :obj:`mmcv.DataContainer`.
+
+        Args:
+            results (dict): Result dict contains the data to convert.
+
+        Returns:
+            dict: The result dict contains the data converted to
+                :obj:`mmcv.DataContainer`.
+        """
+
+        for field in self.fields:
+            field = field.copy()
+            key = field.pop('key')
+            results[key] = DC(results[key], **field)
+        return results
+
+    def __repr__(self):
+        return self.__class__.__name__ + f'(fields={self.fields})'
+
+
+@PIPELINES.register_module()
+class DefaultFormatBundle(object):
+    """Default formatting bundle.
+
+    It simplifies the pipeline of formatting common fields, including "img"
+    and "gt_semantic_seg". These fields are formatted as follows.
+
+    - img: (1) transpose, (2) to tensor, (3) to DataContainer (stack=True)
+    - gt_semantic_seg: (1) unsqueeze dim-0, (2) to tensor,
+                       (3) to DataContainer (stack=True)
+    """
+
+    def __call__(self, results):
+        """Call function to transform and format common fields in results.
+
+        Args:
+            results (dict): Result dict contains the data to convert.
+
+        Returns:
+            dict: The result dict contains the data that is formatted with
+                default bundle.
+        """
+
+        if 'img' in results:
+            img = results['img']
+            if len(img.shape) < 3:
+                img = np.expand_dims(img, -1)
+            img = np.ascontiguousarray(img.transpose(2, 0, 1))
+            results['img'] = DC(to_tensor(img), stack=True)
+        if 'gt_semantic_seg' in results:
+            # convert to long
+            results['gt_semantic_seg'] = DC(
+                to_tensor(results['gt_semantic_seg'][None,
+                                                     ...].astype(np.int64)),
+                stack=True)
+        return results
+
+    def __repr__(self):
+        return self.__class__.__name__
+
+
+@PIPELINES.register_module()
+class Collect(object):
+    """Collect data from the loader relevant to the specific task.
+
+    This is usually the last stage of the data loader pipeline. Typically keys
+    is set to some subset of "img", "gt_semantic_seg".
+
+    The "img_meta" item is always populated. The contents of the "img_meta"
+    dictionary depend on "meta_keys". By default this includes:
+
+        - "img_shape": shape of the image input to the network as a tuple
+            (h, w, c).  Note that images may be zero padded on the bottom/right
+            if the batch tensor is larger than this shape.
+
+        - "scale_factor": a float indicating the preprocessing scale
+
+        - "flip": a boolean indicating if image flip transform was used
+
+        - "filename": path to the image file
+
+        - "ori_shape": original shape of the image as a tuple (h, w, c)
+
+        - "pad_shape": image shape after padding
+
+        - "img_norm_cfg": a dict of normalization information:
+            - mean - per channel mean subtraction
+            - std - per channel std divisor
+            - to_rgb - bool indicating if bgr was converted to rgb
+
+    Args:
+        keys (Sequence[str]): Keys of results to be collected in ``data``.
+        meta_keys (Sequence[str], optional): Meta keys to be converted to
+            ``mmcv.DataContainer`` and collected in ``data[img_metas]``.
+            Default: ``('filename', 'ori_filename', 'ori_shape', 'img_shape',
+            'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+            'img_norm_cfg')``
+    """
+
+    def __init__(self,
+                 keys,
+                 meta_keys=('filename', 'ori_filename', 'ori_shape',
+                            'img_shape', 'pad_shape', 'scale_factor', 'flip',
+                            'flip_direction', 'img_norm_cfg')):
+        self.keys = keys
+        self.meta_keys = meta_keys
+
+    def __call__(self, results):
+        """Call function to collect keys in results. The keys in ``meta_keys``
+        will be converted to :obj:`mmcv.DataContainer`.
+
+        Args:
+            results (dict): Result dict contains the data to collect.
+
+        Returns:
+            dict: The result dict contains the following keys:
+                - keys in ``self.keys``
+                - ``img_metas``
+        """
+
+        data = {}
+        img_meta = {}
+        for key in self.meta_keys:
+            img_meta[key] = results[key]
+        data['img_metas'] = DC(img_meta, cpu_only=True)
+        for key in self.keys:
+            data[key] = results[key]
+        return data
+
+    def __repr__(self):
+        return self.__class__.__name__ + \
+               f'(keys={self.keys}, meta_keys={self.meta_keys})'
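+
+
+# Hedged sketch (unused): push a synthetic sample through DefaultFormatBundle
+# and Collect the way a training pipeline would, with all default meta keys
+# filled in by hand.
+def _example_format_and_collect(h=64, w=64):
+    results = dict(
+        img=np.random.randint(0, 255, (h, w, 3), dtype=np.uint8),
+        gt_semantic_seg=np.zeros((h, w), dtype=np.uint8),
+        filename='fake.png', ori_filename='fake.png',
+        ori_shape=(h, w, 3), img_shape=(h, w, 3), pad_shape=(h, w, 3),
+        scale_factor=1.0, flip=False, flip_direction='horizontal',
+        img_norm_cfg=dict(mean=np.zeros(3), std=np.ones(3), to_rgb=False))
+    results = DefaultFormatBundle()(results)
+    return Collect(keys=['img', 'gt_semantic_seg'])(results)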
diff --git a/annotator/uniformer/mmseg/datasets/pipelines/loading.py b/annotator/uniformer/mmseg/datasets/pipelines/loading.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3692ae91f19b9c7ccf6023168788ff42c9e93e3
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/pipelines/loading.py
@@ -0,0 +1,153 @@
+import os.path as osp
+
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class LoadImageFromFile(object):
+    """Load an image from file.
+
+    Required keys are "img_prefix" and "img_info" (a dict that must contain the
+    key "filename"). Added or updated keys are "filename", "img", "img_shape",
+    "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
+    "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).
+
+    Args:
+        to_float32 (bool): Whether to convert the loaded image to a float32
+            numpy array. If set to False, the loaded image is a uint8 array.
+            Defaults to False.
+        color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
+            Defaults to 'color'.
+        file_client_args (dict): Arguments to instantiate a FileClient.
+            See :class:`mmcv.fileio.FileClient` for details.
+            Defaults to ``dict(backend='disk')``.
+        imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default:
+            'cv2'
+    """
+
+    def __init__(self,
+                 to_float32=False,
+                 color_type='color',
+                 file_client_args=dict(backend='disk'),
+                 imdecode_backend='cv2'):
+        self.to_float32 = to_float32
+        self.color_type = color_type
+        self.file_client_args = file_client_args.copy()
+        self.file_client = None
+        self.imdecode_backend = imdecode_backend
+
+    def __call__(self, results):
+        """Call functions to load image and get image meta information.
+
+        Args:
+            results (dict): Result dict from :obj:`mmseg.CustomDataset`.
+
+        Returns:
+            dict: The dict contains loaded image and meta information.
+        """
+
+        if self.file_client is None:
+            self.file_client = mmcv.FileClient(**self.file_client_args)
+
+        if results.get('img_prefix') is not None:
+            filename = osp.join(results['img_prefix'],
+                                results['img_info']['filename'])
+        else:
+            filename = results['img_info']['filename']
+        img_bytes = self.file_client.get(filename)
+        img = mmcv.imfrombytes(
+            img_bytes, flag=self.color_type, backend=self.imdecode_backend)
+        if self.to_float32:
+            img = img.astype(np.float32)
+
+        results['filename'] = filename
+        results['ori_filename'] = results['img_info']['filename']
+        results['img'] = img
+        results['img_shape'] = img.shape
+        results['ori_shape'] = img.shape
+        # Set initial values for default meta_keys
+        results['pad_shape'] = img.shape
+        results['scale_factor'] = 1.0
+        num_channels = 1 if len(img.shape) < 3 else img.shape[2]
+        results['img_norm_cfg'] = dict(
+            mean=np.zeros(num_channels, dtype=np.float32),
+            std=np.ones(num_channels, dtype=np.float32),
+            to_rgb=False)
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(to_float32={self.to_float32},'
+        repr_str += f"color_type='{self.color_type}',"
+        repr_str += f"imdecode_backend='{self.imdecode_backend}')"
+        return repr_str
+
+
+@PIPELINES.register_module()
+class LoadAnnotations(object):
+    """Load annotations for semantic segmentation.
+
+    Args:
+        reduce_zero_label (bool): Whether to reduce all label values by 1.
+            Usually used for datasets where 0 is the background label.
+            Default: False.
+        file_client_args (dict): Arguments to instantiate a FileClient.
+            See :class:`mmcv.fileio.FileClient` for details.
+            Defaults to ``dict(backend='disk')``.
+        imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default:
+            'pillow'
+    """
+
+    def __init__(self,
+                 reduce_zero_label=False,
+                 file_client_args=dict(backend='disk'),
+                 imdecode_backend='pillow'):
+        self.reduce_zero_label = reduce_zero_label
+        self.file_client_args = file_client_args.copy()
+        self.file_client = None
+        self.imdecode_backend = imdecode_backend
+
+    def __call__(self, results):
+        """Call function to load semantic segmentation annotations.
+
+        Args:
+            results (dict): Result dict from :obj:`mmseg.CustomDataset`.
+
+        Returns:
+            dict: The dict contains loaded semantic segmentation annotations.
+        """
+
+        if self.file_client is None:
+            self.file_client = mmcv.FileClient(**self.file_client_args)
+
+        if results.get('seg_prefix', None) is not None:
+            filename = osp.join(results['seg_prefix'],
+                                results['ann_info']['seg_map'])
+        else:
+            filename = results['ann_info']['seg_map']
+        img_bytes = self.file_client.get(filename)
+        gt_semantic_seg = mmcv.imfrombytes(
+            img_bytes, flag='unchanged',
+            backend=self.imdecode_backend).squeeze().astype(np.uint8)
+        # modify if custom classes
+        if results.get('label_map', None) is not None:
+            for old_id, new_id in results['label_map'].items():
+                gt_semantic_seg[gt_semantic_seg == old_id] = new_id
+        # reduce zero_label
+        if self.reduce_zero_label:
+            # avoid using underflow conversion
+            gt_semantic_seg[gt_semantic_seg == 0] = 255
+            gt_semantic_seg = gt_semantic_seg - 1
+            gt_semantic_seg[gt_semantic_seg == 254] = 255
+        results['gt_semantic_seg'] = gt_semantic_seg
+        results['seg_fields'].append('gt_semantic_seg')
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(reduce_zero_label={self.reduce_zero_label},'
+        repr_str += f"imdecode_backend='{self.imdecode_backend}')"
+        return repr_str
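+
+
+# Hedged usage sketch (unused; the prefixes and file names below are
+# hypothetical and must point at real files for the loaders to succeed):
+def _example_load(img_prefix='data/imgs', seg_prefix='data/anns'):
+    results = dict(
+        img_prefix=img_prefix,
+        seg_prefix=seg_prefix,
+        img_info=dict(filename='0001.png'),
+        ann_info=dict(seg_map='0001.png'),
+        seg_fields=[])
+    results = LoadImageFromFile()(results)
+    results = LoadAnnotations()(results)
+    return results['img'], results['gt_semantic_seg']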
diff --git a/annotator/uniformer/mmseg/datasets/pipelines/test_time_aug.py b/annotator/uniformer/mmseg/datasets/pipelines/test_time_aug.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a1611a04d9d927223c9afbe5bf68af04d62937a
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/pipelines/test_time_aug.py
@@ -0,0 +1,133 @@
+import warnings
+
+import annotator.uniformer.mmcv as mmcv
+
+from ..builder import PIPELINES
+from .compose import Compose
+
+
+@PIPELINES.register_module()
+class MultiScaleFlipAug(object):
+    """Test-time augmentation with multiple scales and flipping.
+
+    An example configuration is as follows:
+
+    .. code-block::
+
+        img_scale=(2048, 1024),
+        img_ratios=[0.5, 1.0],
+        flip=True,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=32),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ]
+
+    After MultiScaleFlipAug with the above configuration, the results are
+    wrapped into lists of the same length, as follows:
+
+    .. code-block::
+
+        dict(
+            img=[...],
+            img_shape=[...],
+            scale=[(1024, 512), (1024, 512), (2048, 1024), (2048, 1024)]
+            flip=[False, True, False, True]
+            ...
+        )
+
+    Args:
+        transforms (list[dict]): Transforms to apply in each augmentation.
+        img_scale (None | tuple | list[tuple]): Image scales for resizing.
+        img_ratios (float | list[float]): Image ratios for resizing.
+        flip (bool): Whether to apply flip augmentation. Default: False.
+        flip_direction (str | list[str]): Flip augmentation directions,
+            options are "horizontal" and "vertical". If flip_direction is a
+            list, multiple flip augmentations will be applied.
+            It has no effect when flip == False. Default: "horizontal".
+    """
+
+    def __init__(self,
+                 transforms,
+                 img_scale,
+                 img_ratios=None,
+                 flip=False,
+                 flip_direction='horizontal'):
+        self.transforms = Compose(transforms)
+        if img_ratios is not None:
+            img_ratios = img_ratios if isinstance(img_ratios,
+                                                  list) else [img_ratios]
+            assert mmcv.is_list_of(img_ratios, float)
+        if img_scale is None:
+            # mode 1: given img_scale=None and a range of image ratio
+            self.img_scale = None
+            assert mmcv.is_list_of(img_ratios, float)
+        elif isinstance(img_scale, tuple) and mmcv.is_list_of(
+                img_ratios, float):
+            assert len(img_scale) == 2
+            # mode 2: given a scale and a range of image ratio
+            self.img_scale = [(int(img_scale[0] * ratio),
+                               int(img_scale[1] * ratio))
+                              for ratio in img_ratios]
+        else:
+            # mode 3: given multiple scales
+            self.img_scale = img_scale if isinstance(img_scale,
+                                                     list) else [img_scale]
+        assert mmcv.is_list_of(self.img_scale, tuple) or self.img_scale is None
+        self.flip = flip
+        self.img_ratios = img_ratios
+        self.flip_direction = flip_direction if isinstance(
+            flip_direction, list) else [flip_direction]
+        assert mmcv.is_list_of(self.flip_direction, str)
+        if not self.flip and self.flip_direction != ['horizontal']:
+            warnings.warn(
+                'flip_direction has no effect when flip is set to False')
+        if (self.flip
+                and not any([t['type'] == 'RandomFlip' for t in transforms])):
+            warnings.warn(
+                'flip has no effect when RandomFlip is not in transforms')
+
+    def __call__(self, results):
+        """Call function to apply test time augment transforms on results.
+
+        Args:
+            results (dict): Result dict contains the data to transform.
+
+        Returns:
+           dict[str: list]: The augmented data, where each value is wrapped
+               into a list.
+        """
+
+        aug_data = []
+        if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float):
+            h, w = results['img'].shape[:2]
+            img_scale = [(int(w * ratio), int(h * ratio))
+                         for ratio in self.img_ratios]
+        else:
+            img_scale = self.img_scale
+        flip_aug = [False, True] if self.flip else [False]
+        for scale in img_scale:
+            for flip in flip_aug:
+                for direction in self.flip_direction:
+                    _results = results.copy()
+                    _results['scale'] = scale
+                    _results['flip'] = flip
+                    _results['flip_direction'] = direction
+                    data = self.transforms(_results)
+                    aug_data.append(data)
+        # list of dict to dict of list
+        aug_data_dict = {key: [] for key in aug_data[0]}
+        for data in aug_data:
+            for key, val in data.items():
+                aug_data_dict[key].append(val)
+        return aug_data_dict
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(transforms={self.transforms}, '
+        repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
+        repr_str += f'flip_direction={self.flip_direction})'
+        return repr_str
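+
+
+# Hedged sketch (unused; assumes the full pipelines package is imported so
+# 'Resize', 'RandomFlip', 'ImageToTensor' and 'Collect' are registered): with
+# two ratios and flip=True the wrapper yields 2 scales x 2 flips = 4 augmented
+# copies per input image.
+def _example_tta():
+    return MultiScaleFlipAug(
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ],
+        img_scale=(2048, 1024),
+        img_ratios=[0.5, 1.0],
+        flip=True)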
diff --git a/annotator/uniformer/mmseg/datasets/pipelines/transforms.py b/annotator/uniformer/mmseg/datasets/pipelines/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..94e869b252ef6d8b43604add2bbc02f034614bfb
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/pipelines/transforms.py
@@ -0,0 +1,889 @@
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+from annotator.uniformer.mmcv.utils import deprecated_api_warning, is_tuple_of
+from numpy import random
+
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class Resize(object):
+    """Resize images & seg.
+
+    This transform resizes the input image to some scale. If the input dict
+    contains the key "scale", then the scale in the input dict is used,
+    otherwise the specified scale in the init method is used.
+
+    ``img_scale`` can be None, a tuple (single-scale) or a list of tuple
+    (multi-scale). There are 4 multiscale modes:
+
+    - ``ratio_range is not None``:
+    1. When img_scale is None, img_scale is the shape of image in results
+        (img_scale = results['img'].shape[:2]) and the image is resized based
+        on the original size. (mode 1)
+    2. When img_scale is a tuple (single-scale), randomly sample a ratio from
+        the ratio range and multiply it with the image scale. (mode 2)
+
+    - ``ratio_range is None and multiscale_mode == "range"``: randomly sample a
+    scale from a range. (mode 3)
+
+    - ``ratio_range is None and multiscale_mode == "value"``: randomly sample a
+    scale from multiple scales. (mode 4)
+
+    Args:
+        img_scale (tuple or list[tuple]): Images scales for resizing.
+        multiscale_mode (str): Either "range" or "value".
+        ratio_range (tuple[float]): (min_ratio, max_ratio)
+        keep_ratio (bool): Whether to keep the aspect ratio when resizing the
+            image.
+    """
+
+    def __init__(self,
+                 img_scale=None,
+                 multiscale_mode='range',
+                 ratio_range=None,
+                 keep_ratio=True):
+        if img_scale is None:
+            self.img_scale = None
+        else:
+            if isinstance(img_scale, list):
+                self.img_scale = img_scale
+            else:
+                self.img_scale = [img_scale]
+            assert mmcv.is_list_of(self.img_scale, tuple)
+
+        if ratio_range is not None:
+            # mode 1: given img_scale=None and a range of image ratio
+            # mode 2: given a scale and a range of image ratio
+            assert self.img_scale is None or len(self.img_scale) == 1
+        else:
+            # mode 3 and 4: given multiple scales or a range of scales
+            assert multiscale_mode in ['value', 'range']
+
+        self.multiscale_mode = multiscale_mode
+        self.ratio_range = ratio_range
+        self.keep_ratio = keep_ratio
+
+    @staticmethod
+    def random_select(img_scales):
+        """Randomly select an img_scale from given candidates.
+
+        Args:
+            img_scales (list[tuple]): Images scales for selection.
+
+        Returns:
+            (tuple, int): Returns a tuple ``(img_scale, scale_idx)``,
+                where ``img_scale`` is the selected image scale and
+                ``scale_idx`` is the selected index in the given candidates.
+        """
+
+        assert mmcv.is_list_of(img_scales, tuple)
+        scale_idx = np.random.randint(len(img_scales))
+        img_scale = img_scales[scale_idx]
+        return img_scale, scale_idx
+
+    @staticmethod
+    def random_sample(img_scales):
+        """Randomly sample an img_scale when ``multiscale_mode=='range'``.
+
+        Args:
+            img_scales (list[tuple]): Images scale range for sampling.
+                There must be two tuples in img_scales, which specify the lower
+                and upper bound of image scales.
+
+        Returns:
+            (tuple, None): Returns a tuple ``(img_scale, None)``, where
+                ``img_scale`` is the sampled scale and None is a placeholder
+                to be consistent with :func:`random_select`.
+        """
+
+        assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2
+        img_scale_long = [max(s) for s in img_scales]
+        img_scale_short = [min(s) for s in img_scales]
+        long_edge = np.random.randint(
+            min(img_scale_long),
+            max(img_scale_long) + 1)
+        short_edge = np.random.randint(
+            min(img_scale_short),
+            max(img_scale_short) + 1)
+        img_scale = (long_edge, short_edge)
+        return img_scale, None
+
+    @staticmethod
+    def random_sample_ratio(img_scale, ratio_range):
+        """Randomly sample an img_scale when ``ratio_range`` is specified.
+
+        A ratio will be randomly sampled from the range specified by
+        ``ratio_range``. Then it would be multiplied with ``img_scale`` to
+        generate sampled scale.
+
+        Args:
+            img_scale (tuple): Images scale base to multiply with ratio.
+            ratio_range (tuple[float]): The minimum and maximum ratio to scale
+                the ``img_scale``.
+
+        Returns:
+            (tuple, None): Returns a tuple ``(scale, None)``, where
+                ``scale`` is the sampled ratio multiplied with ``img_scale``
+                None is just a placeholder to be consistent with
+                :func:`random_select`.
+        """
+
+        assert isinstance(img_scale, tuple) and len(img_scale) == 2
+        min_ratio, max_ratio = ratio_range
+        assert min_ratio <= max_ratio
+        ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
+        scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)
+        return scale, None
+
+    def _random_scale(self, results):
+        """Randomly sample an img_scale according to ``ratio_range`` and
+        ``multiscale_mode``.
+
+        If ``ratio_range`` is specified, a ratio will be sampled and be
+        multiplied with ``img_scale``.
+        If multiple scales are specified by ``img_scale``, a scale will be
+        sampled according to ``multiscale_mode``.
+        Otherwise, single scale will be used.
+
+        Args:
+            results (dict): Result dict from :obj:`dataset`.
+
+        Returns:
+            dict: Two new keys ``scale`` and ``scale_idx`` are added into
+                ``results``, which would be used by subsequent pipelines.
+        """
+
+        if self.ratio_range is not None:
+            if self.img_scale is None:
+                h, w = results['img'].shape[:2]
+                scale, scale_idx = self.random_sample_ratio((w, h),
+                                                            self.ratio_range)
+            else:
+                scale, scale_idx = self.random_sample_ratio(
+                    self.img_scale[0], self.ratio_range)
+        elif len(self.img_scale) == 1:
+            scale, scale_idx = self.img_scale[0], 0
+        elif self.multiscale_mode == 'range':
+            scale, scale_idx = self.random_sample(self.img_scale)
+        elif self.multiscale_mode == 'value':
+            scale, scale_idx = self.random_select(self.img_scale)
+        else:
+            raise NotImplementedError
+
+        results['scale'] = scale
+        results['scale_idx'] = scale_idx
+
+    def _resize_img(self, results):
+        """Resize images with ``results['scale']``."""
+        if self.keep_ratio:
+            img, scale_factor = mmcv.imrescale(
+                results['img'], results['scale'], return_scale=True)
+            # w_scale and h_scale may differ slightly from the true rescale
+            # factor; a real fix should be done in mmcv.imrescale in the future
+            new_h, new_w = img.shape[:2]
+            h, w = results['img'].shape[:2]
+            w_scale = new_w / w
+            h_scale = new_h / h
+        else:
+            img, w_scale, h_scale = mmcv.imresize(
+                results['img'], results['scale'], return_scale=True)
+        scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],
+                                dtype=np.float32)
+        results['img'] = img
+        results['img_shape'] = img.shape
+        results['pad_shape'] = img.shape  # in case that there is no padding
+        results['scale_factor'] = scale_factor
+        results['keep_ratio'] = self.keep_ratio
+
+    def _resize_seg(self, results):
+        """Resize semantic segmentation map with ``results['scale']``."""
+        for key in results.get('seg_fields', []):
+            if self.keep_ratio:
+                gt_seg = mmcv.imrescale(
+                    results[key], results['scale'], interpolation='nearest')
+            else:
+                gt_seg = mmcv.imresize(
+                    results[key], results['scale'], interpolation='nearest')
+            results[key] = gt_seg
+
+    def __call__(self, results):
+        """Call function to resize images and semantic segmentation maps.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor',
+                'keep_ratio' keys are added into result dict.
+        """
+
+        if 'scale' not in results:
+            self._random_scale(results)
+        self._resize_img(results)
+        self._resize_seg(results)
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += (f'(img_scale={self.img_scale}, '
+                     f'multiscale_mode={self.multiscale_mode}, '
+                     f'ratio_range={self.ratio_range}, '
+                     f'keep_ratio={self.keep_ratio})')
+        return repr_str
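+
+
+# Hedged sketch (unused): mode 2 above in action -- a fixed base scale plus a
+# ratio range, so each call samples r in [0.5, 2.0] and rescales towards
+# (int(2048 * r), int(512 * r)) while keeping the aspect ratio.
+def _example_resize_mode2():
+    resize = Resize(img_scale=(2048, 512), ratio_range=(0.5, 2.0))
+    results = dict(
+        img=np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8),
+        seg_fields=[])
+    return resize(results)['img_shape']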
+
+
+@PIPELINES.register_module()
+class RandomFlip(object):
+    """Flip the image & seg.
+
+    If the input dict contains the key "flip", then the flag will be used,
+    otherwise it will be randomly decided by a ratio specified in the init
+    method.
+
+    Args:
+        prob (float, optional): The flipping probability. Default: None.
+        direction(str, optional): The flipping direction. Options are
+            'horizontal' and 'vertical'. Default: 'horizontal'.
+    """
+
+    @deprecated_api_warning({'flip_ratio': 'prob'}, cls_name='RandomFlip')
+    def __init__(self, prob=None, direction='horizontal'):
+        self.prob = prob
+        self.direction = direction
+        if prob is not None:
+            assert prob >= 0 and prob <= 1
+        assert direction in ['horizontal', 'vertical']
+
+    def __call__(self, results):
+        """Call function to flip the image and semantic segmentation maps.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Flipped results, 'flip', 'flip_direction' keys are added into
+                result dict.
+        """
+
+        if 'flip' not in results:
+            flip = True if np.random.rand() < self.prob else False
+            results['flip'] = flip
+        if 'flip_direction' not in results:
+            results['flip_direction'] = self.direction
+        if results['flip']:
+            # flip image
+            results['img'] = mmcv.imflip(
+                results['img'], direction=results['flip_direction'])
+
+            # flip segs
+            for key in results.get('seg_fields', []):
+                # use copy() to make numpy stride positive
+                results[key] = mmcv.imflip(
+                    results[key], direction=results['flip_direction']).copy()
+        return results
+
+    def __repr__(self):
+        return self.__class__.__name__ + f'(prob={self.prob})'
+
+
+@PIPELINES.register_module()
+class Pad(object):
+    """Pad the image & mask.
+
+    There are two padding modes: (1) pad to a fixed size and (2) pad to the
+    minimum size that is divisible by some number.
+    Added keys are "pad_shape", "pad_fixed_size" and "pad_size_divisor".
+
+    Args:
+        size (tuple, optional): Fixed padding size.
+        size_divisor (int, optional): The divisor of padded size.
+        pad_val (float, optional): Padding value. Default: 0.
+        seg_pad_val (float, optional): Padding value of segmentation map.
+            Default: 255.
+    """
+
+    def __init__(self,
+                 size=None,
+                 size_divisor=None,
+                 pad_val=0,
+                 seg_pad_val=255):
+        self.size = size
+        self.size_divisor = size_divisor
+        self.pad_val = pad_val
+        self.seg_pad_val = seg_pad_val
+        # only one of size and size_divisor should be valid
+        assert size is not None or size_divisor is not None
+        assert size is None or size_divisor is None
+
+    def _pad_img(self, results):
+        """Pad images according to ``self.size``."""
+        if self.size is not None:
+            padded_img = mmcv.impad(
+                results['img'], shape=self.size, pad_val=self.pad_val)
+        elif self.size_divisor is not None:
+            padded_img = mmcv.impad_to_multiple(
+                results['img'], self.size_divisor, pad_val=self.pad_val)
+        results['img'] = padded_img
+        results['pad_shape'] = padded_img.shape
+        results['pad_fixed_size'] = self.size
+        results['pad_size_divisor'] = self.size_divisor
+
+    def _pad_seg(self, results):
+        """Pad masks according to ``results['pad_shape']``."""
+        for key in results.get('seg_fields', []):
+            results[key] = mmcv.impad(
+                results[key],
+                shape=results['pad_shape'][:2],
+                pad_val=self.seg_pad_val)
+
+    def __call__(self, results):
+        """Call function to pad images, masks, semantic segmentation maps.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Updated result dict.
+        """
+
+        self._pad_img(results)
+        self._pad_seg(results)
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(size={self.size}, size_divisor={self.size_divisor}, ' \
+                    f'pad_val={self.pad_val})'
+        return repr_str
+
+
+@PIPELINES.register_module()
+class Normalize(object):
+    """Normalize the image.
+
+    Added key is "img_norm_cfg".
+
+    Args:
+        mean (sequence): Mean values of 3 channels.
+        std (sequence): Std values of 3 channels.
+        to_rgb (bool): Whether to convert the image from BGR to RGB,
+            default is true.
+    """
+
+    def __init__(self, mean, std, to_rgb=True):
+        self.mean = np.array(mean, dtype=np.float32)
+        self.std = np.array(std, dtype=np.float32)
+        self.to_rgb = to_rgb
+
+    def __call__(self, results):
+        """Call function to normalize images.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Normalized results, 'img_norm_cfg' key is added into
+                result dict.
+        """
+
+        results['img'] = mmcv.imnormalize(results['img'], self.mean, self.std,
+                                          self.to_rgb)
+        results['img_norm_cfg'] = dict(
+            mean=self.mean, std=self.std, to_rgb=self.to_rgb)
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(mean={self.mean}, std={self.std}, to_rgb=' \
+                    f'{self.to_rgb})'
+        return repr_str
+
+
+@PIPELINES.register_module()
+class Rerange(object):
+    """Rerange the image pixel value.
+
+    Args:
+        min_value (float or int): Minimum value of the reranged image.
+            Default: 0.
+        max_value (float or int): Maximum value of the reranged image.
+            Default: 255.
+    """
+
+    def __init__(self, min_value=0, max_value=255):
+        assert isinstance(min_value, float) or isinstance(min_value, int)
+        assert isinstance(max_value, float) or isinstance(max_value, int)
+        assert min_value < max_value
+        self.min_value = min_value
+        self.max_value = max_value
+
+    def __call__(self, results):
+        """Call function to rerange images.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+        Returns:
+            dict: Reranged results.
+        """
+
+        img = results['img']
+        img_min_value = np.min(img)
+        img_max_value = np.max(img)
+
+        assert img_min_value < img_max_value
+        # rerange to [0, 1]
+        img = (img - img_min_value) / (img_max_value - img_min_value)
+        # rerange to [min_value, max_value]
+        img = img * (self.max_value - self.min_value) + self.min_value
+        results['img'] = img
+
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(min_value={self.min_value}, max_value={self.max_value})'
+        return repr_str
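+
+
+# Hedged numeric sketch (unused): an input spanning [10, 60] is mapped
+# linearly onto [0, 255], so 10 -> 0.0, 35 -> 127.5 and 60 -> 255.0.
+def _example_rerange():
+    img = np.array([[10.0, 35.0, 60.0]])
+    return Rerange(min_value=0, max_value=255)(dict(img=img))['img']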
+
+
+@PIPELINES.register_module()
+class CLAHE(object):
+    """Use CLAHE method to process the image.
+
+    See `Zuiderveld, K. Contrast Limited Adaptive Histogram Equalization.
+    Graphics Gems, 1994: 474-485.` for more information.
+
+    Args:
+        clip_limit (float): Threshold for contrast limiting. Default: 40.0.
+        tile_grid_size (tuple[int]): Size of grid for histogram equalization.
+            Input image will be divided into equally sized rectangular tiles.
+            It defines the number of tiles in row and column. Default: (8, 8).
+    """
+
+    def __init__(self, clip_limit=40.0, tile_grid_size=(8, 8)):
+        assert isinstance(clip_limit, (float, int))
+        self.clip_limit = clip_limit
+        assert is_tuple_of(tile_grid_size, int)
+        assert len(tile_grid_size) == 2
+        self.tile_grid_size = tile_grid_size
+
+    def __call__(self, results):
+        """Call function to use the CLAHE method to process images.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Processed results.
+        """
+
+        for i in range(results['img'].shape[2]):
+            results['img'][:, :, i] = mmcv.clahe(
+                np.array(results['img'][:, :, i], dtype=np.uint8),
+                self.clip_limit, self.tile_grid_size)
+
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(clip_limit={self.clip_limit}, '\
+                    f'tile_grid_size={self.tile_grid_size})'
+        return repr_str
+
+
+@PIPELINES.register_module()
+class RandomCrop(object):
+    """Random crop the image & seg.
+
+    Args:
+        crop_size (tuple): Expected size after cropping, (h, w).
+        cat_max_ratio (float): The maximum ratio that a single category could
+            occupy.
+    """
+
+    def __init__(self, crop_size, cat_max_ratio=1., ignore_index=255):
+        assert crop_size[0] > 0 and crop_size[1] > 0
+        self.crop_size = crop_size
+        self.cat_max_ratio = cat_max_ratio
+        self.ignore_index = ignore_index
+
+    def get_crop_bbox(self, img):
+        """Randomly get a crop bounding box."""
+        margin_h = max(img.shape[0] - self.crop_size[0], 0)
+        margin_w = max(img.shape[1] - self.crop_size[1], 0)
+        offset_h = np.random.randint(0, margin_h + 1)
+        offset_w = np.random.randint(0, margin_w + 1)
+        crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]
+        crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]
+
+        return crop_y1, crop_y2, crop_x1, crop_x2
+
+    def crop(self, img, crop_bbox):
+        """Crop from ``img``"""
+        crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
+        img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
+        return img
+
+    def __call__(self, results):
+        """Call function to randomly crop images, semantic segmentation maps.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Randomly cropped results, 'img_shape' key in result dict is
+                updated according to crop size.
+        """
+
+        img = results['img']
+        crop_bbox = self.get_crop_bbox(img)
+        if self.cat_max_ratio < 1.:
+            # Re-sample the crop up to 10 times to satisfy cat_max_ratio
+            for _ in range(10):
+                seg_temp = self.crop(results['gt_semantic_seg'], crop_bbox)
+                labels, cnt = np.unique(seg_temp, return_counts=True)
+                cnt = cnt[labels != self.ignore_index]
+                if len(cnt) > 1 and np.max(cnt) / np.sum(
+                        cnt) < self.cat_max_ratio:
+                    break
+                crop_bbox = self.get_crop_bbox(img)
+
+        # crop the image
+        img = self.crop(img, crop_bbox)
+        img_shape = img.shape
+        results['img'] = img
+        results['img_shape'] = img_shape
+
+        # crop semantic seg
+        for key in results.get('seg_fields', []):
+            results[key] = self.crop(results[key], crop_bbox)
+
+        return results
+
+    def __repr__(self):
+        return self.__class__.__name__ + f'(crop_size={self.crop_size})'
+
+
+@PIPELINES.register_module()
+class RandomRotate(object):
+    """Rotate the image & seg.
+
+    Args:
+        prob (float): The rotation probability.
+        degree (float, tuple[float]): Range of degrees to select from. If
+            degree is a number instead of tuple like (min, max),
+            the range of degree will be (``-degree``, ``+degree``)
+        pad_val (float, optional): Padding value of image. Default: 0.
+        seg_pad_val (float, optional): Padding value of segmentation map.
+            Default: 255.
+        center (tuple[float], optional): Center point (w, h) of the rotation in
+            the source image. If not specified, the center of the image will be
+            used. Default: None.
+        auto_bound (bool): Whether to adjust the image size to cover the whole
+            rotated image. Default: False
+    """
+
+    def __init__(self,
+                 prob,
+                 degree,
+                 pad_val=0,
+                 seg_pad_val=255,
+                 center=None,
+                 auto_bound=False):
+        self.prob = prob
+        assert prob >= 0 and prob <= 1
+        if isinstance(degree, (float, int)):
+            assert degree > 0, f'degree {degree} should be positive'
+            self.degree = (-degree, degree)
+        else:
+            self.degree = degree
+        assert len(self.degree) == 2, f'degree {self.degree} should be a ' \
+                                      f'tuple of (min, max)'
+        self.pad_val = pad_val
+        self.seg_pad_val = seg_pad_val
+        self.center = center
+        self.auto_bound = auto_bound
+
+    def __call__(self, results):
+        """Call function to rotate image, semantic segmentation maps.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Rotated results.
+        """
+
+        rotate = True if np.random.rand() < self.prob else False
+        degree = np.random.uniform(min(*self.degree), max(*self.degree))
+        if rotate:
+            # rotate image
+            results['img'] = mmcv.imrotate(
+                results['img'],
+                angle=degree,
+                border_value=self.pad_val,
+                center=self.center,
+                auto_bound=self.auto_bound)
+
+            # rotate segs
+            for key in results.get('seg_fields', []):
+                results[key] = mmcv.imrotate(
+                    results[key],
+                    angle=degree,
+                    border_value=self.seg_pad_val,
+                    center=self.center,
+                    auto_bound=self.auto_bound,
+                    interpolation='nearest')
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(prob={self.prob}, ' \
+                    f'degree={self.degree}, ' \
+                    f'pad_val={self.pad_val}, ' \
+                    f'seg_pad_val={self.seg_pad_val}, ' \
+                    f'center={self.center}, ' \
+                    f'auto_bound={self.auto_bound})'
+        return repr_str
+
+
+@PIPELINES.register_module()
+class RGB2Gray(object):
+    """Convert RGB image to grayscale image.
+
+    This transform calculates the weighted mean of the input image channels
+    with ``weights`` and then expands the channels to ``out_channels``. When
+    ``out_channels`` is None, the number of output channels is the same as the
+    number of input channels.
+
+    Args:
+        out_channels (int): Expected number of output channels after
+            transforming. Default: None.
+        weights (tuple[float]): The weights to calculate the weighted mean.
+            Default: (0.299, 0.587, 0.114).
+    """
+
+    def __init__(self, out_channels=None, weights=(0.299, 0.587, 0.114)):
+        assert out_channels is None or out_channels > 0
+        self.out_channels = out_channels
+        assert isinstance(weights, tuple)
+        for item in weights:
+            assert isinstance(item, (float, int))
+        self.weights = weights
+
+    def __call__(self, results):
+        """Call function to convert RGB image to grayscale image.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Result dict with grayscale image.
+        """
+        img = results['img']
+        assert len(img.shape) == 3
+        assert img.shape[2] == len(self.weights)
+        weights = np.array(self.weights).reshape((1, 1, -1))
+        img = (img * weights).sum(2, keepdims=True)
+        if self.out_channels is None:
+            img = img.repeat(weights.shape[2], axis=2)
+        else:
+            img = img.repeat(self.out_channels, axis=2)
+
+        results['img'] = img
+        results['img_shape'] = img.shape
+
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(out_channels={self.out_channels}, ' \
+                    f'weights={self.weights})'
+        return repr_str
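+
+
+# Hedged sketch (unused): with the default BT.601 weights a pure-red pixel
+# maps to 0.299 on every output channel (the single gray channel is repeated
+# because out_channels is None).
+def _example_rgb2gray():
+    img = np.zeros((1, 1, 3), dtype=np.float32)
+    img[..., 0] = 1.0  # red channel
+    return RGB2Gray()(dict(img=img))['img']  # every channel ~= 0.299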
+
+
+@PIPELINES.register_module()
+class AdjustGamma(object):
+    """Use gamma correction to process the image.
+
+    Args:
+        gamma (float or int): Gamma value used in gamma correction.
+            Default: 1.0.
+    """
+
+    def __init__(self, gamma=1.0):
+        assert isinstance(gamma, float) or isinstance(gamma, int)
+        assert gamma > 0
+        self.gamma = gamma
+        inv_gamma = 1.0 / gamma
+        self.table = np.array([(i / 255.0)**inv_gamma * 255
+                               for i in np.arange(256)]).astype('uint8')
+
+    def __call__(self, results):
+        """Call function to process the image with gamma correction.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Processed results.
+        """
+
+        results['img'] = mmcv.lut_transform(
+            np.array(results['img'], dtype=np.uint8), self.table)
+
+        return results
+
+    def __repr__(self):
+        return self.__class__.__name__ + f'(gamma={self.gamma})'
+
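+# Worked example of the lookup table above (the gamma value is illustrative):
+# with gamma=2.0, inv_gamma=0.5, so an input intensity of 128 maps to
+# (128 / 255) ** 0.5 * 255 ≈ 180, i.e. gamma > 1 brightens mid-tones.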
+
+@PIPELINES.register_module()
+class SegRescale(object):
+    """Rescale semantic segmentation maps.
+
+    Args:
+        scale_factor (float): The scale factor of the final output.
+    """
+
+    def __init__(self, scale_factor=1):
+        self.scale_factor = scale_factor
+
+    def __call__(self, results):
+        """Call function to scale the semantic segmentation map.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Result dict with semantic segmentation map scaled.
+        """
+        for key in results.get('seg_fields', []):
+            if self.scale_factor != 1:
+                results[key] = mmcv.imrescale(
+                    results[key], self.scale_factor, interpolation='nearest')
+        return results
+
+    def __repr__(self):
+        return self.__class__.__name__ + f'(scale_factor={self.scale_factor})'
+
+
+@PIPELINES.register_module()
+class PhotoMetricDistortion(object):
+    """Apply photometric distortion to image sequentially, every transformation
+    is applied with a probability of 0.5. The position of random contrast is in
+    second or second to last.
+
+    1. random brightness
+    2. random contrast (mode 0)
+    3. convert color from BGR to HSV
+    4. random saturation
+    5. random hue
+    6. convert color from HSV to BGR
+    7. random contrast (mode 1)
+
+    Args:
+        brightness_delta (int): delta of brightness.
+        contrast_range (tuple): range of contrast.
+        saturation_range (tuple): range of saturation.
+        hue_delta (int): delta of hue.
+    """
+
+    def __init__(self,
+                 brightness_delta=32,
+                 contrast_range=(0.5, 1.5),
+                 saturation_range=(0.5, 1.5),
+                 hue_delta=18):
+        self.brightness_delta = brightness_delta
+        self.contrast_lower, self.contrast_upper = contrast_range
+        self.saturation_lower, self.saturation_upper = saturation_range
+        self.hue_delta = hue_delta
+
+    def convert(self, img, alpha=1, beta=0):
+        """Multiple with alpha and add beat with clip."""
+        img = img.astype(np.float32) * alpha + beta
+        img = np.clip(img, 0, 255)
+        return img.astype(np.uint8)
+
+    def brightness(self, img):
+        """Brightness distortion."""
+        if random.randint(2):
+            return self.convert(
+                img,
+                beta=random.uniform(-self.brightness_delta,
+                                    self.brightness_delta))
+        return img
+
+    def contrast(self, img):
+        """Contrast distortion."""
+        if random.randint(2):
+            return self.convert(
+                img,
+                alpha=random.uniform(self.contrast_lower, self.contrast_upper))
+        return img
+
+    def saturation(self, img):
+        """Saturation distortion."""
+        if random.randint(2):
+            img = mmcv.bgr2hsv(img)
+            img[:, :, 1] = self.convert(
+                img[:, :, 1],
+                alpha=random.uniform(self.saturation_lower,
+                                     self.saturation_upper))
+            img = mmcv.hsv2bgr(img)
+        return img
+
+    def hue(self, img):
+        """Hue distortion."""
+        if random.randint(2):
+            img = mmcv.bgr2hsv(img)
+            img[:, :,
+                0] = (img[:, :, 0].astype(int) +
+                      random.randint(-self.hue_delta, self.hue_delta)) % 180
+            img = mmcv.hsv2bgr(img)
+        return img
+
+    def __call__(self, results):
+        """Call function to perform photometric distortion on images.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Result dict with images distorted.
+        """
+
+        img = results['img']
+        # random brightness
+        img = self.brightness(img)
+
+        # mode == 0 --> do random contrast first
+        # mode == 1 --> do random contrast last
+        mode = random.randint(2)
+        if mode == 1:
+            img = self.contrast(img)
+
+        # random saturation
+        img = self.saturation(img)
+
+        # random hue
+        img = self.hue(img)
+
+        # random contrast
+        if mode == 0:
+            img = self.contrast(img)
+
+        results['img'] = img
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += (f'(brightness_delta={self.brightness_delta}, '
+                     f'contrast_range=({self.contrast_lower}, '
+                     f'{self.contrast_upper}), '
+                     f'saturation_range=({self.saturation_lower}, '
+                     f'{self.saturation_upper}), '
+                     f'hue_delta={self.hue_delta})')
+        return repr_str
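+
+# A minimal usage sketch (the random input image is assumed, not part of the
+# original pipeline config):
+#   >>> distort = PhotoMetricDistortion()
+#   >>> img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
+#   >>> out = distort(dict(img=img))
+#   >>> out['img'].shape, out['img'].dtype   # same shape, uint8, randomly jittered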
diff --git a/annotator/uniformer/mmseg/datasets/stare.py b/annotator/uniformer/mmseg/datasets/stare.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbd14e0920e7f6a73baff1432e5a32ccfdb0dfae
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/stare.py
@@ -0,0 +1,27 @@
+import os.path as osp
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class STAREDataset(CustomDataset):
+    """STARE dataset.
+
+    In the segmentation map annotation for STARE, 0 stands for background,
+    which is included in the 2 categories. ``reduce_zero_label`` is fixed to
+    False. The
+    ``img_suffix`` is fixed to '.png' and ``seg_map_suffix`` is fixed to
+    '.ah.png'.
+    """
+
+    CLASSES = ('background', 'vessel')
+
+    PALETTE = [[120, 120, 120], [6, 230, 230]]
+
+    def __init__(self, **kwargs):
+        super(STAREDataset, self).__init__(
+            img_suffix='.png',
+            seg_map_suffix='.ah.png',
+            reduce_zero_label=False,
+            **kwargs)
+        assert osp.exists(self.img_dir)
diff --git a/annotator/uniformer/mmseg/datasets/voc.py b/annotator/uniformer/mmseg/datasets/voc.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8855203b14ee0dc4da9099a2945d4aedcffbcd6
--- /dev/null
+++ b/annotator/uniformer/mmseg/datasets/voc.py
@@ -0,0 +1,29 @@
+import os.path as osp
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class PascalVOCDataset(CustomDataset):
+    """Pascal VOC dataset.
+
+    Args:
+        split (str): Split txt file for Pascal VOC.
+    """
+
+    CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
+               'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
+               'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
+               'train', 'tvmonitor')
+
+    PALETTE = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128],
+               [128, 0, 128], [0, 128, 128], [128, 128, 128], [64, 0, 0],
+               [192, 0, 0], [64, 128, 0], [192, 128, 0], [64, 0, 128],
+               [192, 0, 128], [64, 128, 128], [192, 128, 128], [0, 64, 0],
+               [128, 64, 0], [0, 192, 0], [128, 192, 0], [0, 64, 128]]
+
+    def __init__(self, split, **kwargs):
+        super(PascalVOCDataset, self).__init__(
+            img_suffix='.jpg', seg_map_suffix='.png', split=split, **kwargs)
+        assert osp.exists(self.img_dir) and self.split is not None
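+
+# Sketch of how this dataset is typically built through the DATASETS registry
+# (the ``build_dataset`` import path and the data paths are assumptions, not
+# part of this file):
+#   >>> from annotator.uniformer.mmseg.datasets import build_dataset
+#   >>> cfg = dict(type='PascalVOCDataset',
+#   ...            data_root='data/VOCdevkit/VOC2012',
+#   ...            img_dir='JPEGImages',
+#   ...            ann_dir='SegmentationClass',
+#   ...            split='ImageSets/Segmentation/train.txt',
+#   ...            pipeline=[])
+#   >>> dataset = build_dataset(cfg)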
diff --git a/annotator/uniformer/mmseg/models/__init__.py b/annotator/uniformer/mmseg/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cf93f8bec9cf0cef0a3bd76ca3ca92eb188f535
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/__init__.py
@@ -0,0 +1,12 @@
+from .backbones import *  # noqa: F401,F403
+from .builder import (BACKBONES, HEADS, LOSSES, SEGMENTORS, build_backbone,
+                      build_head, build_loss, build_segmentor)
+from .decode_heads import *  # noqa: F401,F403
+from .losses import *  # noqa: F401,F403
+from .necks import *  # noqa: F401,F403
+from .segmentors import *  # noqa: F401,F403
+
+__all__ = [
+    'BACKBONES', 'HEADS', 'LOSSES', 'SEGMENTORS', 'build_backbone',
+    'build_head', 'build_loss', 'build_segmentor'
+]
diff --git a/annotator/uniformer/mmseg/models/backbones/__init__.py b/annotator/uniformer/mmseg/models/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8339983905fb5d20bae42ba6f76fea75d278b1aa
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/backbones/__init__.py
@@ -0,0 +1,17 @@
+from .cgnet import CGNet
+# from .fast_scnn import FastSCNN
+from .hrnet import HRNet
+from .mobilenet_v2 import MobileNetV2
+from .mobilenet_v3 import MobileNetV3
+from .resnest import ResNeSt
+from .resnet import ResNet, ResNetV1c, ResNetV1d
+from .resnext import ResNeXt
+from .unet import UNet
+from .vit import VisionTransformer
+from .uniformer import UniFormer
+
+__all__ = [
+    'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet',
+    'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
+    'VisionTransformer', 'UniFormer'
+]
diff --git a/annotator/uniformer/mmseg/models/backbones/cgnet.py b/annotator/uniformer/mmseg/models/backbones/cgnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8bca442c8f18179f217e40c298fb5ef39df77c4
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/backbones/cgnet.py
@@ -0,0 +1,367 @@
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from annotator.uniformer.mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
+                      constant_init, kaiming_init)
+from annotator.uniformer.mmcv.runner import load_checkpoint
+from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm
+
+from annotator.uniformer.mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+
+
+class GlobalContextExtractor(nn.Module):
+    """Global Context Extractor for CGNet.
+
+    This class is employed to refine the joint feature of both the local
+    feature and the surrounding context.
+
+    Args:
+        channel (int): Number of input feature channels.
+        reduction (int): Reductions for global context extractor. Default: 16.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+    """
+
+    def __init__(self, channel, reduction=16, with_cp=False):
+        super(GlobalContextExtractor, self).__init__()
+        self.channel = channel
+        self.reduction = reduction
+        assert reduction >= 1 and channel >= reduction
+        self.with_cp = with_cp
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True),
+            nn.Linear(channel // reduction, channel), nn.Sigmoid())
+
+    def forward(self, x):
+
+        def _inner_forward(x):
+            num_batch, num_channel = x.size()[:2]
+            y = self.avg_pool(x).view(num_batch, num_channel)
+            y = self.fc(y).view(num_batch, num_channel, 1, 1)
+            return x * y
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(_inner_forward, x)
+        else:
+            out = _inner_forward(x)
+
+        return out
+
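+# GlobalContextExtractor is a squeeze-and-excitation style gate: global average
+# pooling yields one weight per channel, which then rescales the input feature
+# map. A shape sketch (the tensor sizes are illustrative):
+#   >>> gce = GlobalContextExtractor(channel=64, reduction=16)
+#   >>> gce(torch.rand(2, 64, 32, 32)).shape
+#   torch.Size([2, 64, 32, 32])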
+
+class ContextGuidedBlock(nn.Module):
+    """Context Guided Block for CGNet.
+
+    This class consists of four components: local feature extractor,
+    surrounding feature extractor, joint feature extractor and global
+    context extractor.
+
+    Args:
+        in_channels (int): Number of input feature channels.
+        out_channels (int): Number of output feature channels.
+        dilation (int): Dilation rate for surrounding context extractor.
+            Default: 2.
+        reduction (int): Reduction for global context extractor. Default: 16.
+        skip_connect (bool): Add input to output or not. Default: True.
+        downsample (bool): Downsample the input to 1/2 or not. Default: False.
+        conv_cfg (dict): Config dict for convolution layer.
+            Default: None, which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN', requires_grad=True).
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='PReLU').
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 dilation=2,
+                 reduction=16,
+                 skip_connect=True,
+                 downsample=False,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN', requires_grad=True),
+                 act_cfg=dict(type='PReLU'),
+                 with_cp=False):
+        super(ContextGuidedBlock, self).__init__()
+        self.with_cp = with_cp
+        self.downsample = downsample
+
+        channels = out_channels if downsample else out_channels // 2
+        if 'type' in act_cfg and act_cfg['type'] == 'PReLU':
+            act_cfg['num_parameters'] = channels
+        kernel_size = 3 if downsample else 1
+        stride = 2 if downsample else 1
+        padding = (kernel_size - 1) // 2
+
+        self.conv1x1 = ConvModule(
+            in_channels,
+            channels,
+            kernel_size,
+            stride,
+            padding,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+
+        self.f_loc = build_conv_layer(
+            conv_cfg,
+            channels,
+            channels,
+            kernel_size=3,
+            padding=1,
+            groups=channels,
+            bias=False)
+        self.f_sur = build_conv_layer(
+            conv_cfg,
+            channels,
+            channels,
+            kernel_size=3,
+            padding=dilation,
+            groups=channels,
+            dilation=dilation,
+            bias=False)
+
+        self.bn = build_norm_layer(norm_cfg, 2 * channels)[1]
+        self.activate = nn.PReLU(2 * channels)
+
+        if downsample:
+            self.bottleneck = build_conv_layer(
+                conv_cfg,
+                2 * channels,
+                out_channels,
+                kernel_size=1,
+                bias=False)
+
+        self.skip_connect = skip_connect and not downsample
+        self.f_glo = GlobalContextExtractor(out_channels, reduction, with_cp)
+
+    def forward(self, x):
+
+        def _inner_forward(x):
+            out = self.conv1x1(x)
+            loc = self.f_loc(out)
+            sur = self.f_sur(out)
+
+            joi_feat = torch.cat([loc, sur], 1)  # the joint feature
+            joi_feat = self.bn(joi_feat)
+            joi_feat = self.activate(joi_feat)
+            if self.downsample:
+                joi_feat = self.bottleneck(joi_feat)  # channel = out_channels
+            # f_glo is employed to refine the joint feature
+            out = self.f_glo(joi_feat)
+
+            if self.skip_connect:
+                return x + out
+            else:
+                return out
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(_inner_forward, x)
+        else:
+            out = _inner_forward(x)
+
+        return out
+
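+# Shape sketch for a non-downsampling block with skip_connect=True (the sizes
+# are illustrative): input (N, 64, H, W) -> conv1x1 -> (N, 32, H, W); f_loc and
+# f_sur each keep 32 channels and are concatenated to (N, 64, H, W); f_glo then
+# gates the channels and the input is added back, so the output stays
+# (N, 64, H, W).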
+
+class InputInjection(nn.Module):
+    """Downsampling module for CGNet."""
+
+    def __init__(self, num_downsampling):
+        super(InputInjection, self).__init__()
+        self.pool = nn.ModuleList()
+        for i in range(num_downsampling):
+            self.pool.append(nn.AvgPool2d(3, stride=2, padding=1))
+
+    def forward(self, x):
+        for pool in self.pool:
+            x = pool(x)
+        return x
+
+
+@BACKBONES.register_module()
+class CGNet(nn.Module):
+    """CGNet backbone.
+
+    A Light-weight Context Guided Network for Semantic Segmentation
+    arXiv: https://arxiv.org/abs/1811.08201
+
+    Args:
+        in_channels (int): Number of input image channels. Normally 3.
+        num_channels (tuple[int]): Numbers of feature channels at each stage.
+            Default: (32, 64, 128).
+        num_blocks (tuple[int]): Numbers of CG blocks at stage 1 and stage 2.
+            Default: (3, 21).
+        dilations (tuple[int]): Dilation rate for surrounding context
+            extractors at stage 1 and stage 2. Default: (2, 4).
+        reductions (tuple[int]): Reductions for global context extractors at
+            stage 1 and stage 2. Default: (8, 16).
+        conv_cfg (dict): Config dict for convolution layer.
+            Default: None, which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN', requires_grad=True).
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='PReLU').
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Default: False.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+    """
+
+    def __init__(self,
+                 in_channels=3,
+                 num_channels=(32, 64, 128),
+                 num_blocks=(3, 21),
+                 dilations=(2, 4),
+                 reductions=(8, 16),
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN', requires_grad=True),
+                 act_cfg=dict(type='PReLU'),
+                 norm_eval=False,
+                 with_cp=False):
+
+        super(CGNet, self).__init__()
+        self.in_channels = in_channels
+        self.num_channels = num_channels
+        assert isinstance(self.num_channels, tuple) and len(
+            self.num_channels) == 3
+        self.num_blocks = num_blocks
+        assert isinstance(self.num_blocks, tuple) and len(self.num_blocks) == 2
+        self.dilations = dilations
+        assert isinstance(self.dilations, tuple) and len(self.dilations) == 2
+        self.reductions = reductions
+        assert isinstance(self.reductions, tuple) and len(self.reductions) == 2
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        if 'type' in self.act_cfg and self.act_cfg['type'] == 'PReLU':
+            self.act_cfg['num_parameters'] = num_channels[0]
+        self.norm_eval = norm_eval
+        self.with_cp = with_cp
+
+        cur_channels = in_channels
+        self.stem = nn.ModuleList()
+        for i in range(3):
+            self.stem.append(
+                ConvModule(
+                    cur_channels,
+                    num_channels[0],
+                    3,
+                    2 if i == 0 else 1,
+                    padding=1,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg))
+            cur_channels = num_channels[0]
+
+        self.inject_2x = InputInjection(1)  # down-sample for Input, factor=2
+        self.inject_4x = InputInjection(2)  # down-sample for Input, factor=4
+
+        cur_channels += in_channels
+        self.norm_prelu_0 = nn.Sequential(
+            build_norm_layer(norm_cfg, cur_channels)[1],
+            nn.PReLU(cur_channels))
+
+        # stage 1
+        self.level1 = nn.ModuleList()
+        for i in range(num_blocks[0]):
+            self.level1.append(
+                ContextGuidedBlock(
+                    cur_channels if i == 0 else num_channels[1],
+                    num_channels[1],
+                    dilations[0],
+                    reductions[0],
+                    downsample=(i == 0),
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg,
+                    with_cp=with_cp))  # CG block
+
+        cur_channels = 2 * num_channels[1] + in_channels
+        self.norm_prelu_1 = nn.Sequential(
+            build_norm_layer(norm_cfg, cur_channels)[1],
+            nn.PReLU(cur_channels))
+
+        # stage 2
+        self.level2 = nn.ModuleList()
+        for i in range(num_blocks[1]):
+            self.level2.append(
+                ContextGuidedBlock(
+                    cur_channels if i == 0 else num_channels[2],
+                    num_channels[2],
+                    dilations[1],
+                    reductions[1],
+                    downsample=(i == 0),
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg,
+                    with_cp=with_cp))  # CG block
+
+        cur_channels = 2 * num_channels[2]
+        self.norm_prelu_2 = nn.Sequential(
+            build_norm_layer(norm_cfg, cur_channels)[1],
+            nn.PReLU(cur_channels))
+
+    def forward(self, x):
+        output = []
+
+        # stage 0
+        inp_2x = self.inject_2x(x)
+        inp_4x = self.inject_4x(x)
+        for layer in self.stem:
+            x = layer(x)
+        x = self.norm_prelu_0(torch.cat([x, inp_2x], 1))
+        output.append(x)
+
+        # stage 1
+        for i, layer in enumerate(self.level1):
+            x = layer(x)
+            if i == 0:
+                down1 = x
+        x = self.norm_prelu_1(torch.cat([x, down1, inp_4x], 1))
+        output.append(x)
+
+        # stage 2
+        for i, layer in enumerate(self.level2):
+            x = layer(x)
+            if i == 0:
+                down2 = x
+        x = self.norm_prelu_2(torch.cat([down2, x], 1))
+        output.append(x)
+
+        return output
+
+    def init_weights(self, pretrained=None):
+        """Initialize the weights in backbone.
+
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+        """
+        if isinstance(pretrained, str):
+            logger = get_root_logger()
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            for m in self.modules():
+                if isinstance(m, (nn.Conv2d, nn.Linear)):
+                    kaiming_init(m)
+                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                    constant_init(m, 1)
+                elif isinstance(m, nn.PReLU):
+                    constant_init(m, 0)
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    def train(self, mode=True):
+        """Convert the model into training mode will keeping the normalization
+        layer freezed."""
+        super(CGNet, self).train(mode)
+        if mode and self.norm_eval:
+            for m in self.modules():
+                # trick: eval() only has an effect on BatchNorm layers
+                if isinstance(m, _BatchNorm):
+                    m.eval()
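+
+# Shape sketch of the three-stage output pyramid (the input size is an
+# assumption for illustration; channel counts follow the defaults above):
+#   >>> model = CGNet()
+#   >>> model.init_weights()
+#   >>> feats = model(torch.rand(1, 3, 64, 64))
+#   >>> [tuple(f.shape) for f in feats]
+#   [(1, 35, 32, 32), (1, 131, 16, 16), (1, 256, 8, 8)]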
diff --git a/annotator/uniformer/mmseg/models/backbones/fast_scnn.py b/annotator/uniformer/mmseg/models/backbones/fast_scnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..38c2350177cbc2066f45add568d30eb6041f74f3
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/backbones/fast_scnn.py
@@ -0,0 +1,375 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, constant_init,
+                      kaiming_init)
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from annotator.uniformer.mmseg.models.decode_heads.psp_head import PPM
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import BACKBONES
+from ..utils.inverted_residual import InvertedResidual
+
+
+class LearningToDownsample(nn.Module):
+    """Learning to downsample module.
+
+    Args:
+        in_channels (int): Number of input channels.
+        dw_channels (tuple[int]): Number of output channels of the first and
+            the second depthwise conv (dwconv) layers.
+        out_channels (int): Number of output channels of the whole
+            'learning to downsample' module.
+        conv_cfg (dict | None): Config of conv layers. Default: None
+        norm_cfg (dict | None): Config of norm layers. Default:
+            dict(type='BN')
+        act_cfg (dict): Config of activation layers. Default:
+            dict(type='ReLU')
+    """
+
+    def __init__(self,
+                 in_channels,
+                 dw_channels,
+                 out_channels,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU')):
+        super(LearningToDownsample, self).__init__()
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        dw_channels1 = dw_channels[0]
+        dw_channels2 = dw_channels[1]
+
+        self.conv = ConvModule(
+            in_channels,
+            dw_channels1,
+            3,
+            stride=2,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        self.dsconv1 = DepthwiseSeparableConvModule(
+            dw_channels1,
+            dw_channels2,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            norm_cfg=self.norm_cfg)
+        self.dsconv2 = DepthwiseSeparableConvModule(
+            dw_channels2,
+            out_channels,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            norm_cfg=self.norm_cfg)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.dsconv1(x)
+        x = self.dsconv2(x)
+        return x
+
+
+class GlobalFeatureExtractor(nn.Module):
+    """Global feature extractor module.
+
+    Args:
+        in_channels (int): Number of input channels of the GFE module.
+            Default: 64
+        block_channels (tuple[int]): Tuple of ints. Each int specifies the
+            number of output channels of each Inverted Residual module.
+            Default: (64, 96, 128)
+        out_channels(int): Number of output channels of the GFE module.
+            Default: 128
+        expand_ratio (int): Adjusts number of channels of the hidden layer
+            in InvertedResidual by this amount.
+            Default: 6
+        num_blocks (tuple[int]): Tuple of ints. Each int specifies the
+            number of times each Inverted Residual module is repeated.
+            The repeated Inverted Residual modules are called a 'group'.
+            Default: (3, 3, 3)
+        strides (tuple[int]): Tuple of ints. Each int specifies
+            the downsampling factor of each 'group'.
+            Default: (2, 2, 1)
+        pool_scales (tuple[int]): Tuple of ints. Each int specifies
+            the parameter required in 'global average pooling' within PPM.
+            Default: (1, 2, 3, 6)
+        conv_cfg (dict | None): Config of conv layers. Default: None
+        norm_cfg (dict | None): Config of norm layers. Default:
+            dict(type='BN')
+        act_cfg (dict): Config of activation layers. Default:
+            dict(type='ReLU')
+        align_corners (bool): align_corners argument of F.interpolate.
+            Default: False
+    """
+
+    def __init__(self,
+                 in_channels=64,
+                 block_channels=(64, 96, 128),
+                 out_channels=128,
+                 expand_ratio=6,
+                 num_blocks=(3, 3, 3),
+                 strides=(2, 2, 1),
+                 pool_scales=(1, 2, 3, 6),
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 align_corners=False):
+        super(GlobalFeatureExtractor, self).__init__()
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        assert len(block_channels) == len(num_blocks) == 3
+        self.bottleneck1 = self._make_layer(in_channels, block_channels[0],
+                                            num_blocks[0], strides[0],
+                                            expand_ratio)
+        self.bottleneck2 = self._make_layer(block_channels[0],
+                                            block_channels[1], num_blocks[1],
+                                            strides[1], expand_ratio)
+        self.bottleneck3 = self._make_layer(block_channels[1],
+                                            block_channels[2], num_blocks[2],
+                                            strides[2], expand_ratio)
+        self.ppm = PPM(
+            pool_scales,
+            block_channels[2],
+            block_channels[2] // 4,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg,
+            align_corners=align_corners)
+        self.out = ConvModule(
+            block_channels[2] * 2,
+            out_channels,
+            1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+
+    def _make_layer(self,
+                    in_channels,
+                    out_channels,
+                    blocks,
+                    stride=1,
+                    expand_ratio=6):
+        layers = [
+            InvertedResidual(
+                in_channels,
+                out_channels,
+                stride,
+                expand_ratio,
+                norm_cfg=self.norm_cfg)
+        ]
+        for i in range(1, blocks):
+            layers.append(
+                InvertedResidual(
+                    out_channels,
+                    out_channels,
+                    1,
+                    expand_ratio,
+                    norm_cfg=self.norm_cfg))
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.bottleneck1(x)
+        x = self.bottleneck2(x)
+        x = self.bottleneck3(x)
+        x = torch.cat([x, *self.ppm(x)], dim=1)
+        x = self.out(x)
+        return x
+
+
+class FeatureFusionModule(nn.Module):
+    """Feature fusion module.
+
+    Args:
+        higher_in_channels (int): Number of input channels of the
+            higher-resolution branch.
+        lower_in_channels (int): Number of input channels of the
+            lower-resolution branch.
+        out_channels (int): Number of output channels.
+        conv_cfg (dict | None): Config of conv layers. Default: None
+        norm_cfg (dict | None): Config of norm layers. Default:
+            dict(type='BN')
+        act_cfg (dict): Config of activation layers. Default:
+            dict(type='ReLU')
+        align_corners (bool): align_corners argument of F.interpolate.
+            Default: False
+    """
+
+    def __init__(self,
+                 higher_in_channels,
+                 lower_in_channels,
+                 out_channels,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 align_corners=False):
+        super(FeatureFusionModule, self).__init__()
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        self.align_corners = align_corners
+        self.dwconv = ConvModule(
+            lower_in_channels,
+            out_channels,
+            1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        self.conv_lower_res = ConvModule(
+            out_channels,
+            out_channels,
+            1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=None)
+        self.conv_higher_res = ConvModule(
+            higher_in_channels,
+            out_channels,
+            1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=None)
+        self.relu = nn.ReLU(True)
+
+    def forward(self, higher_res_feature, lower_res_feature):
+        lower_res_feature = resize(
+            lower_res_feature,
+            size=higher_res_feature.size()[2:],
+            mode='bilinear',
+            align_corners=self.align_corners)
+        lower_res_feature = self.dwconv(lower_res_feature)
+        lower_res_feature = self.conv_lower_res(lower_res_feature)
+
+        higher_res_feature = self.conv_higher_res(higher_res_feature)
+        out = higher_res_feature + lower_res_feature
+        return self.relu(out)
+
+
+@BACKBONES.register_module()
+class FastSCNN(nn.Module):
+    """Fast-SCNN Backbone.
+
+    Args:
+        in_channels (int): Number of input image channels. Default: 3.
+        downsample_dw_channels (tuple[int]): Number of output channels after
+            the first conv layer & the second conv layer in
+            Learning-To-Downsample (LTD) module.
+            Default: (32, 48).
+        global_in_channels (int): Number of input channels of
+            Global Feature Extractor(GFE).
+            Equal to number of output channels of LTD.
+            Default: 64.
+        global_block_channels (tuple[int]): Tuple of integers that describe
+            the output channels for each of the MobileNet-v2 bottleneck
+            residual blocks in GFE.
+            Default: (64, 96, 128).
+        global_block_strides (tuple[int]): Tuple of integers
+            that describe the strides (downsampling factors) for each of the
+            MobileNet-v2 bottleneck residual blocks in GFE.
+            Default: (2, 2, 1).
+        global_out_channels (int): Number of output channels of GFE.
+            Default: 128.
+        higher_in_channels (int): Number of input channels of the higher
+            resolution branch in FFM.
+            Equal to global_in_channels.
+            Default: 64.
+        lower_in_channels (int): Number of input channels of the lower
+            resolution branch in FFM.
+            Equal to global_out_channels.
+            Default: 128.
+        fusion_out_channels (int): Number of output channels of FFM.
+            Default: 128.
+        out_indices (tuple): Tuple of indices of list
+            [higher_res_features, lower_res_features, fusion_output].
+            Often set to (0,1,2) to enable aux. heads.
+            Default: (0, 1, 2).
+        conv_cfg (dict | None): Config of conv layers. Default: None
+        norm_cfg (dict | None): Config of norm layers. Default:
+            dict(type='BN')
+        act_cfg (dict): Config of activation layers. Default:
+            dict(type='ReLU')
+        align_corners (bool): align_corners argument of F.interpolate.
+            Default: False
+    """
+
+    def __init__(self,
+                 in_channels=3,
+                 downsample_dw_channels=(32, 48),
+                 global_in_channels=64,
+                 global_block_channels=(64, 96, 128),
+                 global_block_strides=(2, 2, 1),
+                 global_out_channels=128,
+                 higher_in_channels=64,
+                 lower_in_channels=128,
+                 fusion_out_channels=128,
+                 out_indices=(0, 1, 2),
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 align_corners=False):
+
+        super(FastSCNN, self).__init__()
+        if global_in_channels != higher_in_channels:
+            raise AssertionError('Global Input Channels must be the same '
+                                 'as Higher Input Channels!')
+        elif global_out_channels != lower_in_channels:
+            raise AssertionError('Global Output Channels must be the same '
+                                 'as Lower Input Channels!')
+
+        self.in_channels = in_channels
+        self.downsample_dw_channels1 = downsample_dw_channels[0]
+        self.downsample_dw_channels2 = downsample_dw_channels[1]
+        self.global_in_channels = global_in_channels
+        self.global_block_channels = global_block_channels
+        self.global_block_strides = global_block_strides
+        self.global_out_channels = global_out_channels
+        self.higher_in_channels = higher_in_channels
+        self.lower_in_channels = lower_in_channels
+        self.fusion_out_channels = fusion_out_channels
+        self.out_indices = out_indices
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        self.align_corners = align_corners
+        self.learning_to_downsample = LearningToDownsample(
+            in_channels,
+            downsample_dw_channels,
+            global_in_channels,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        self.global_feature_extractor = GlobalFeatureExtractor(
+            global_in_channels,
+            global_block_channels,
+            global_out_channels,
+            strides=self.global_block_strides,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg,
+            align_corners=self.align_corners)
+        self.feature_fusion = FeatureFusionModule(
+            higher_in_channels,
+            lower_in_channels,
+            fusion_out_channels,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg,
+            align_corners=self.align_corners)
+
+    def init_weights(self, pretrained=None):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                kaiming_init(m)
+            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                constant_init(m, 1)
+
+    def forward(self, x):
+        higher_res_features = self.learning_to_downsample(x)
+        lower_res_features = self.global_feature_extractor(higher_res_features)
+        fusion_output = self.feature_fusion(higher_res_features,
+                                            lower_res_features)
+
+        outs = [higher_res_features, lower_res_features, fusion_output]
+        outs = [outs[i] for i in self.out_indices]
+        return tuple(outs)
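+
+# Shape sketch of the three default outputs (the input size is assumed; with a
+# 512x1024 input the higher-res, lower-res and fused features come out at
+# roughly 1/8, 1/32 and 1/8 of the input resolution):
+#   >>> model = FastSCNN()
+#   >>> outs = model(torch.rand(1, 3, 512, 1024))
+#   >>> [tuple(o.shape) for o in outs]
+#   [(1, 64, 64, 128), (1, 128, 16, 32), (1, 128, 64, 128)]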
diff --git a/annotator/uniformer/mmseg/models/backbones/hrnet.py b/annotator/uniformer/mmseg/models/backbones/hrnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..331ebf3ccb8597b3f507670753789073fc3c946d
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/backbones/hrnet.py
@@ -0,0 +1,555 @@
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
+                      kaiming_init)
+from annotator.uniformer.mmcv.runner import load_checkpoint
+from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm
+
+from annotator.uniformer.mmseg.ops import Upsample, resize
+from annotator.uniformer.mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+from .resnet import BasicBlock, Bottleneck
+
+
+class HRModule(nn.Module):
+    """High-Resolution Module for HRNet.
+
+    In this module, every branch has 4 BasicBlocks/Bottlenecks, and the
+    fusion/exchange of features across branches happens here.
+    """
+
+    def __init__(self,
+                 num_branches,
+                 blocks,
+                 num_blocks,
+                 in_channels,
+                 num_channels,
+                 multiscale_output=True,
+                 with_cp=False,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN', requires_grad=True)):
+        super(HRModule, self).__init__()
+        self._check_branches(num_branches, num_blocks, in_channels,
+                             num_channels)
+
+        self.in_channels = in_channels
+        self.num_branches = num_branches
+
+        self.multiscale_output = multiscale_output
+        self.norm_cfg = norm_cfg
+        self.conv_cfg = conv_cfg
+        self.with_cp = with_cp
+        self.branches = self._make_branches(num_branches, blocks, num_blocks,
+                                            num_channels)
+        self.fuse_layers = self._make_fuse_layers()
+        self.relu = nn.ReLU(inplace=False)
+
+    def _check_branches(self, num_branches, num_blocks, in_channels,
+                        num_channels):
+        """Check branches configuration."""
+        if num_branches != len(num_blocks):
+            error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_BLOCKS(' \
+                        f'{len(num_blocks)})'
+            raise ValueError(error_msg)
+
+        if num_branches != len(num_channels):
+            error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_CHANNELS(' \
+                        f'{len(num_channels)})'
+            raise ValueError(error_msg)
+
+        if num_branches != len(in_channels):
+            error_msg = f'NUM_BRANCHES({num_branches}) <> NUM_INCHANNELS(' \
+                        f'{len(in_channels)})'
+            raise ValueError(error_msg)
+
+    def _make_one_branch(self,
+                         branch_index,
+                         block,
+                         num_blocks,
+                         num_channels,
+                         stride=1):
+        """Build one branch."""
+        downsample = None
+        if stride != 1 or \
+                self.in_channels[branch_index] != \
+                num_channels[branch_index] * block.expansion:
+            downsample = nn.Sequential(
+                build_conv_layer(
+                    self.conv_cfg,
+                    self.in_channels[branch_index],
+                    num_channels[branch_index] * block.expansion,
+                    kernel_size=1,
+                    stride=stride,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, num_channels[branch_index] *
+                                 block.expansion)[1])
+
+        layers = []
+        layers.append(
+            block(
+                self.in_channels[branch_index],
+                num_channels[branch_index],
+                stride,
+                downsample=downsample,
+                with_cp=self.with_cp,
+                norm_cfg=self.norm_cfg,
+                conv_cfg=self.conv_cfg))
+        self.in_channels[branch_index] = \
+            num_channels[branch_index] * block.expansion
+        for i in range(1, num_blocks[branch_index]):
+            layers.append(
+                block(
+                    self.in_channels[branch_index],
+                    num_channels[branch_index],
+                    with_cp=self.with_cp,
+                    norm_cfg=self.norm_cfg,
+                    conv_cfg=self.conv_cfg))
+
+        return nn.Sequential(*layers)
+
+    def _make_branches(self, num_branches, block, num_blocks, num_channels):
+        """Build multiple branch."""
+        branches = []
+
+        for i in range(num_branches):
+            branches.append(
+                self._make_one_branch(i, block, num_blocks, num_channels))
+
+        return nn.ModuleList(branches)
+
+    def _make_fuse_layers(self):
+        """Build fuse layer."""
+        if self.num_branches == 1:
+            return None
+
+        num_branches = self.num_branches
+        in_channels = self.in_channels
+        fuse_layers = []
+        num_out_branches = num_branches if self.multiscale_output else 1
+        for i in range(num_out_branches):
+            fuse_layer = []
+            for j in range(num_branches):
+                if j > i:
+                    fuse_layer.append(
+                        nn.Sequential(
+                            build_conv_layer(
+                                self.conv_cfg,
+                                in_channels[j],
+                                in_channels[i],
+                                kernel_size=1,
+                                stride=1,
+                                padding=0,
+                                bias=False),
+                            build_norm_layer(self.norm_cfg, in_channels[i])[1],
+                            # we set align_corners=False for HRNet
+                            Upsample(
+                                scale_factor=2**(j - i),
+                                mode='bilinear',
+                                align_corners=False)))
+                elif j == i:
+                    fuse_layer.append(None)
+                else:
+                    conv_downsamples = []
+                    for k in range(i - j):
+                        if k == i - j - 1:
+                            conv_downsamples.append(
+                                nn.Sequential(
+                                    build_conv_layer(
+                                        self.conv_cfg,
+                                        in_channels[j],
+                                        in_channels[i],
+                                        kernel_size=3,
+                                        stride=2,
+                                        padding=1,
+                                        bias=False),
+                                    build_norm_layer(self.norm_cfg,
+                                                     in_channels[i])[1]))
+                        else:
+                            conv_downsamples.append(
+                                nn.Sequential(
+                                    build_conv_layer(
+                                        self.conv_cfg,
+                                        in_channels[j],
+                                        in_channels[j],
+                                        kernel_size=3,
+                                        stride=2,
+                                        padding=1,
+                                        bias=False),
+                                    build_norm_layer(self.norm_cfg,
+                                                     in_channels[j])[1],
+                                    nn.ReLU(inplace=False)))
+                    fuse_layer.append(nn.Sequential(*conv_downsamples))
+            fuse_layers.append(nn.ModuleList(fuse_layer))
+
+        return nn.ModuleList(fuse_layers)
+
+    def forward(self, x):
+        """Forward function."""
+        if self.num_branches == 1:
+            return [self.branches[0](x[0])]
+
+        for i in range(self.num_branches):
+            x[i] = self.branches[i](x[i])
+
+        x_fuse = []
+        for i in range(len(self.fuse_layers)):
+            y = 0
+            for j in range(self.num_branches):
+                if i == j:
+                    y += x[j]
+                elif j > i:
+                    y = y + resize(
+                        self.fuse_layers[i][j](x[j]),
+                        size=x[i].shape[2:],
+                        mode='bilinear',
+                        align_corners=False)
+                else:
+                    y += self.fuse_layers[i][j](x[j])
+            x_fuse.append(self.relu(y))
+        return x_fuse
+
+
+@BACKBONES.register_module()
+class HRNet(nn.Module):
+    """HRNet backbone.
+
+    High-Resolution Representations for Labeling Pixels and Regions
+    arXiv: https://arxiv.org/abs/1904.04514
+
+    Args:
+        extra (dict): detailed configuration for each stage of HRNet.
+        in_channels (int): Number of input image channels. Normally 3.
+        conv_cfg (dict): dictionary to construct and config conv layer.
+        norm_cfg (dict): dictionary to construct and config norm layer.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed.
+        zero_init_residual (bool): whether to use zero init for last norm layer
+            in resblocks to let them behave as identity.
+
+    Example:
+        >>> from annotator.uniformer.mmseg.models import HRNet
+        >>> import torch
+        >>> extra = dict(
+        >>>     stage1=dict(
+        >>>         num_modules=1,
+        >>>         num_branches=1,
+        >>>         block='BOTTLENECK',
+        >>>         num_blocks=(4, ),
+        >>>         num_channels=(64, )),
+        >>>     stage2=dict(
+        >>>         num_modules=1,
+        >>>         num_branches=2,
+        >>>         block='BASIC',
+        >>>         num_blocks=(4, 4),
+        >>>         num_channels=(32, 64)),
+        >>>     stage3=dict(
+        >>>         num_modules=4,
+        >>>         num_branches=3,
+        >>>         block='BASIC',
+        >>>         num_blocks=(4, 4, 4),
+        >>>         num_channels=(32, 64, 128)),
+        >>>     stage4=dict(
+        >>>         num_modules=3,
+        >>>         num_branches=4,
+        >>>         block='BASIC',
+        >>>         num_blocks=(4, 4, 4, 4),
+        >>>         num_channels=(32, 64, 128, 256)))
+        >>> self = HRNet(extra, in_channels=1)
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 1, 32, 32)
+        >>> level_outputs = self.forward(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        (1, 32, 8, 8)
+        (1, 64, 4, 4)
+        (1, 128, 2, 2)
+        (1, 256, 1, 1)
+    """
+
+    blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
+
+    def __init__(self,
+                 extra,
+                 in_channels=3,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN', requires_grad=True),
+                 norm_eval=False,
+                 with_cp=False,
+                 zero_init_residual=False):
+        super(HRNet, self).__init__()
+        self.extra = extra
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.norm_eval = norm_eval
+        self.with_cp = with_cp
+        self.zero_init_residual = zero_init_residual
+
+        # stem net
+        self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
+        self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)
+
+        self.conv1 = build_conv_layer(
+            self.conv_cfg,
+            in_channels,
+            64,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            bias=False)
+
+        self.add_module(self.norm1_name, norm1)
+        self.conv2 = build_conv_layer(
+            self.conv_cfg,
+            64,
+            64,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            bias=False)
+
+        self.add_module(self.norm2_name, norm2)
+        self.relu = nn.ReLU(inplace=True)
+
+        # stage 1
+        self.stage1_cfg = self.extra['stage1']
+        num_channels = self.stage1_cfg['num_channels'][0]
+        block_type = self.stage1_cfg['block']
+        num_blocks = self.stage1_cfg['num_blocks'][0]
+
+        block = self.blocks_dict[block_type]
+        stage1_out_channels = num_channels * block.expansion
+        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
+
+        # stage 2
+        self.stage2_cfg = self.extra['stage2']
+        num_channels = self.stage2_cfg['num_channels']
+        block_type = self.stage2_cfg['block']
+
+        block = self.blocks_dict[block_type]
+        num_channels = [channel * block.expansion for channel in num_channels]
+        self.transition1 = self._make_transition_layer([stage1_out_channels],
+                                                       num_channels)
+        self.stage2, pre_stage_channels = self._make_stage(
+            self.stage2_cfg, num_channels)
+
+        # stage 3
+        self.stage3_cfg = self.extra['stage3']
+        num_channels = self.stage3_cfg['num_channels']
+        block_type = self.stage3_cfg['block']
+
+        block = self.blocks_dict[block_type]
+        num_channels = [channel * block.expansion for channel in num_channels]
+        self.transition2 = self._make_transition_layer(pre_stage_channels,
+                                                       num_channels)
+        self.stage3, pre_stage_channels = self._make_stage(
+            self.stage3_cfg, num_channels)
+
+        # stage 4
+        self.stage4_cfg = self.extra['stage4']
+        num_channels = self.stage4_cfg['num_channels']
+        block_type = self.stage4_cfg['block']
+
+        block = self.blocks_dict[block_type]
+        num_channels = [channel * block.expansion for channel in num_channels]
+        self.transition3 = self._make_transition_layer(pre_stage_channels,
+                                                       num_channels)
+        self.stage4, pre_stage_channels = self._make_stage(
+            self.stage4_cfg, num_channels)
+
+    @property
+    def norm1(self):
+        """nn.Module: the normalization layer named "norm1" """
+        return getattr(self, self.norm1_name)
+
+    @property
+    def norm2(self):
+        """nn.Module: the normalization layer named "norm2" """
+        return getattr(self, self.norm2_name)
+
+    def _make_transition_layer(self, num_channels_pre_layer,
+                               num_channels_cur_layer):
+        """Make transition layer."""
+        num_branches_cur = len(num_channels_cur_layer)
+        num_branches_pre = len(num_channels_pre_layer)
+
+        transition_layers = []
+        for i in range(num_branches_cur):
+            if i < num_branches_pre:
+                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+                    transition_layers.append(
+                        nn.Sequential(
+                            build_conv_layer(
+                                self.conv_cfg,
+                                num_channels_pre_layer[i],
+                                num_channels_cur_layer[i],
+                                kernel_size=3,
+                                stride=1,
+                                padding=1,
+                                bias=False),
+                            build_norm_layer(self.norm_cfg,
+                                             num_channels_cur_layer[i])[1],
+                            nn.ReLU(inplace=True)))
+                else:
+                    transition_layers.append(None)
+            else:
+                conv_downsamples = []
+                for j in range(i + 1 - num_branches_pre):
+                    in_channels = num_channels_pre_layer[-1]
+                    out_channels = num_channels_cur_layer[i] \
+                        if j == i - num_branches_pre else in_channels
+                    conv_downsamples.append(
+                        nn.Sequential(
+                            build_conv_layer(
+                                self.conv_cfg,
+                                in_channels,
+                                out_channels,
+                                kernel_size=3,
+                                stride=2,
+                                padding=1,
+                                bias=False),
+                            build_norm_layer(self.norm_cfg, out_channels)[1],
+                            nn.ReLU(inplace=True)))
+                transition_layers.append(nn.Sequential(*conv_downsamples))
+
+        return nn.ModuleList(transition_layers)
+
+    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
+        """Make each layer."""
+        downsample = None
+        if stride != 1 or inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                build_conv_layer(
+                    self.conv_cfg,
+                    inplanes,
+                    planes * block.expansion,
+                    kernel_size=1,
+                    stride=stride,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
+
+        layers = []
+        layers.append(
+            block(
+                inplanes,
+                planes,
+                stride,
+                downsample=downsample,
+                with_cp=self.with_cp,
+                norm_cfg=self.norm_cfg,
+                conv_cfg=self.conv_cfg))
+        inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(
+                block(
+                    inplanes,
+                    planes,
+                    with_cp=self.with_cp,
+                    norm_cfg=self.norm_cfg,
+                    conv_cfg=self.conv_cfg))
+
+        return nn.Sequential(*layers)
+
+    def _make_stage(self, layer_config, in_channels, multiscale_output=True):
+        """Make each stage."""
+        num_modules = layer_config['num_modules']
+        num_branches = layer_config['num_branches']
+        num_blocks = layer_config['num_blocks']
+        num_channels = layer_config['num_channels']
+        block = self.blocks_dict[layer_config['block']]
+
+        hr_modules = []
+        for i in range(num_modules):
+            # multi_scale_output is only used for the last module
+            if not multiscale_output and i == num_modules - 1:
+                reset_multiscale_output = False
+            else:
+                reset_multiscale_output = True
+
+            hr_modules.append(
+                HRModule(
+                    num_branches,
+                    block,
+                    num_blocks,
+                    in_channels,
+                    num_channels,
+                    reset_multiscale_output,
+                    with_cp=self.with_cp,
+                    norm_cfg=self.norm_cfg,
+                    conv_cfg=self.conv_cfg))
+
+        return nn.Sequential(*hr_modules), in_channels
+
+    def init_weights(self, pretrained=None):
+        """Initialize the weights in backbone.
+
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+        """
+        if isinstance(pretrained, str):
+            logger = get_root_logger()
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    kaiming_init(m)
+                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                    constant_init(m, 1)
+
+            if self.zero_init_residual:
+                for m in self.modules():
+                    if isinstance(m, Bottleneck):
+                        constant_init(m.norm3, 0)
+                    elif isinstance(m, BasicBlock):
+                        constant_init(m.norm2, 0)
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    def forward(self, x):
+        """Forward function."""
+
+        x = self.conv1(x)
+        x = self.norm1(x)
+        x = self.relu(x)
+        x = self.conv2(x)
+        x = self.norm2(x)
+        x = self.relu(x)
+        x = self.layer1(x)
+
+        x_list = []
+        for i in range(self.stage2_cfg['num_branches']):
+            if self.transition1[i] is not None:
+                x_list.append(self.transition1[i](x))
+            else:
+                x_list.append(x)
+        y_list = self.stage2(x_list)
+
+        x_list = []
+        for i in range(self.stage3_cfg['num_branches']):
+            if self.transition2[i] is not None:
+                x_list.append(self.transition2[i](y_list[-1]))
+            else:
+                x_list.append(y_list[i])
+        y_list = self.stage3(x_list)
+
+        x_list = []
+        for i in range(self.stage4_cfg['num_branches']):
+            if self.transition3[i] is not None:
+                x_list.append(self.transition3[i](y_list[-1]))
+            else:
+                x_list.append(y_list[i])
+        y_list = self.stage4(x_list)
+
+        return y_list
+
+    def train(self, mode=True):
+        """Convert the model into training mode will keeping the normalization
+        layer freezed."""
+        super(HRNet, self).train(mode)
+        if mode and self.norm_eval:
+            for m in self.modules():
+                # trick: eval has an effect on BatchNorm only
+                if isinstance(m, _BatchNorm):
+                    m.eval()
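The forward pass above keeps every resolution branch alive and returns one feature map per branch. A minimal usage sketch (illustrative only, not part of the diff; it assumes the ``extra`` stage config accepted by the HRNet constructor earlier in this file and the HRNetV2-W32 layout, as in upstream mmseg):

import torch
from annotator.uniformer.mmseg.models import HRNet

extra = dict(
    stage1=dict(num_modules=1, num_branches=1, block='BOTTLENECK',
                num_blocks=(4, ), num_channels=(64, )),
    stage2=dict(num_modules=1, num_branches=2, block='BASIC',
                num_blocks=(4, 4), num_channels=(32, 64)),
    stage3=dict(num_modules=4, num_branches=3, block='BASIC',
                num_blocks=(4, 4, 4), num_channels=(32, 64, 128)),
    stage4=dict(num_modules=3, num_branches=4, block='BASIC',
                num_blocks=(4, 4, 4, 4), num_channels=(32, 64, 128, 256)))
model = HRNet(extra=extra)
model.eval()
with torch.no_grad():
    outs = model(torch.rand(1, 3, 64, 64))
# One feature map per branch, at strides 4, 8, 16 and 32 of the input.
for out in outs:
    print(tuple(out.shape))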
diff --git a/annotator/uniformer/mmseg/models/backbones/mobilenet_v2.py b/annotator/uniformer/mmseg/models/backbones/mobilenet_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab6b3791692a0d1b5da3601875711710b7bd01ba
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/backbones/mobilenet_v2.py
@@ -0,0 +1,180 @@
+import logging
+
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule, constant_init, kaiming_init
+from annotator.uniformer.mmcv.runner import load_checkpoint
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from ..builder import BACKBONES
+from ..utils import InvertedResidual, make_divisible
+
+
+@BACKBONES.register_module()
+class MobileNetV2(nn.Module):
+    """MobileNetV2 backbone.
+
+    Args:
+        widen_factor (float): Width multiplier, multiply number of
+            channels in each layer by this amount. Default: 1.0.
+        strides (Sequence[int], optional): Strides of the first block of each
+            layer. If not specified, the default config in ``arch_settings``
+            will be used.
+        dilations (Sequence[int]): Dilation of each layer.
+        out_indices (Sequence[int]): Output from which stages.
+            Default: (1, 2, 4, 6).
+        frozen_stages (int): Stages to be frozen (all param fixed).
+            Default: -1, which means not freezing any parameters.
+        conv_cfg (dict): Config dict for convolution layer.
+            Default: None, which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='ReLU6').
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Default: False.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+    """
+
+    # Parameters to build layers. 3 parameters are needed to construct a
+    # layer, from left to right: expand_ratio, channel, num_blocks.
+    arch_settings = [[1, 16, 1], [6, 24, 2], [6, 32, 3], [6, 64, 4],
+                     [6, 96, 3], [6, 160, 3], [6, 320, 1]]
+
+    def __init__(self,
+                 widen_factor=1.,
+                 strides=(1, 2, 2, 2, 1, 2, 1),
+                 dilations=(1, 1, 1, 1, 1, 1, 1),
+                 out_indices=(1, 2, 4, 6),
+                 frozen_stages=-1,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU6'),
+                 norm_eval=False,
+                 with_cp=False):
+        super(MobileNetV2, self).__init__()
+        self.widen_factor = widen_factor
+        self.strides = strides
+        self.dilations = dilations
+        assert len(strides) == len(dilations) == len(self.arch_settings)
+        self.out_indices = out_indices
+        for index in out_indices:
+            if index not in range(0, 7):
+                raise ValueError('the item in out_indices must be in '
+                                 f'range(0, 7). But received {index}')
+
+        if frozen_stages not in range(-1, 7):
+            raise ValueError('frozen_stages must be in range(-1, 7). '
+                             f'But received {frozen_stages}')
+        self.frozen_stages = frozen_stages
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        self.norm_eval = norm_eval
+        self.with_cp = with_cp
+
+        self.in_channels = make_divisible(32 * widen_factor, 8)
+
+        self.conv1 = ConvModule(
+            in_channels=3,
+            out_channels=self.in_channels,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+
+        self.layers = []
+
+        for i, layer_cfg in enumerate(self.arch_settings):
+            expand_ratio, channel, num_blocks = layer_cfg
+            stride = self.strides[i]
+            dilation = self.dilations[i]
+            out_channels = make_divisible(channel * widen_factor, 8)
+            inverted_res_layer = self.make_layer(
+                out_channels=out_channels,
+                num_blocks=num_blocks,
+                stride=stride,
+                dilation=dilation,
+                expand_ratio=expand_ratio)
+            layer_name = f'layer{i + 1}'
+            self.add_module(layer_name, inverted_res_layer)
+            self.layers.append(layer_name)
+
+    def make_layer(self, out_channels, num_blocks, stride, dilation,
+                   expand_ratio):
+        """Stack InvertedResidual blocks to build a layer for MobileNetV2.
+
+        Args:
+            out_channels (int): out_channels of block.
+            num_blocks (int): Number of blocks.
+            stride (int): Stride of the first block.
+            dilation (int): Dilation of the first block.
+            expand_ratio (int): Expand the number of channels of the
+                hidden layer in InvertedResidual by this ratio.
+        """
+        layers = []
+        for i in range(num_blocks):
+            layers.append(
+                InvertedResidual(
+                    self.in_channels,
+                    out_channels,
+                    stride if i == 0 else 1,
+                    expand_ratio=expand_ratio,
+                    dilation=dilation if i == 0 else 1,
+                    conv_cfg=self.conv_cfg,
+                    norm_cfg=self.norm_cfg,
+                    act_cfg=self.act_cfg,
+                    with_cp=self.with_cp))
+            self.in_channels = out_channels
+
+        return nn.Sequential(*layers)
+
+    def init_weights(self, pretrained=None):
+        if isinstance(pretrained, str):
+            logger = logging.getLogger()
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    kaiming_init(m)
+                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                    constant_init(m, 1)
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    def forward(self, x):
+        x = self.conv1(x)
+
+        outs = []
+        for i, layer_name in enumerate(self.layers):
+            layer = getattr(self, layer_name)
+            x = layer(x)
+            if i in self.out_indices:
+                outs.append(x)
+
+        if len(outs) == 1:
+            return outs[0]
+        else:
+            return tuple(outs)
+
+    def _freeze_stages(self):
+        if self.frozen_stages >= 0:
+            for param in self.conv1.parameters():
+                param.requires_grad = False
+        for i in range(1, self.frozen_stages + 1):
+            layer = getattr(self, f'layer{i}')
+            layer.eval()
+            for param in layer.parameters():
+                param.requires_grad = False
+
+    def train(self, mode=True):
+        super(MobileNetV2, self).train(mode)
+        self._freeze_stages()
+        if mode and self.norm_eval:
+            for m in self.modules():
+                if isinstance(m, _BatchNorm):
+                    m.eval()
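With the defaults above, the backbone returns a four-level feature pyramid. A minimal usage sketch (illustrative only, not part of the diff):

import torch
from annotator.uniformer.mmseg.models.backbones.mobilenet_v2 import MobileNetV2

model = MobileNetV2(widen_factor=1.0, out_indices=(1, 2, 4, 6))
model.init_weights()
model.eval()
with torch.no_grad():
    outs = model(torch.rand(1, 3, 224, 224))
# Feature maps at strides 4, 8, 16 and 32 with 24, 32, 96 and 320 channels.
for out in outs:
    print(tuple(out.shape))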
diff --git a/annotator/uniformer/mmseg/models/backbones/mobilenet_v3.py b/annotator/uniformer/mmseg/models/backbones/mobilenet_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..16817400b4102899794fe64c9644713a4e54e2f9
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/backbones/mobilenet_v3.py
@@ -0,0 +1,255 @@
+import logging
+
+import annotator.uniformer.mmcv as mmcv
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule, constant_init, kaiming_init
+from annotator.uniformer.mmcv.cnn.bricks import Conv2dAdaptivePadding
+from annotator.uniformer.mmcv.runner import load_checkpoint
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from ..builder import BACKBONES
+from ..utils import InvertedResidualV3 as InvertedResidual
+
+
+@BACKBONES.register_module()
+class MobileNetV3(nn.Module):
+    """MobileNetV3 backbone.
+
+    This backbone is the improved implementation of `Searching for MobileNetV3
+    <https://ieeexplore.ieee.org/document/9008835>`_.
+
+    Args:
+        arch (str): Architecture of MobileNetV3, from {'small', 'large'}.
+            Default: 'small'.
+        conv_cfg (dict): Config dict for convolution layer.
+            Default: None, which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN').
+        out_indices (tuple[int]): Output from which layer.
+            Default: (0, 1, 12).
+        frozen_stages (int): Stages to be frozen (all param fixed).
+            Default: -1, which means not freezing any parameters.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Default: False.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save
+            some memory while slowing down the training speed.
+            Default: False.
+    """
+    # Parameters to build each block:
+    #     [kernel size, mid channels, out channels, with_se, act type, stride]
+    arch_settings = {
+        'small': [[3, 16, 16, True, 'ReLU', 2],  # block0 layer1 os=4
+                  [3, 72, 24, False, 'ReLU', 2],  # block1 layer2 os=8
+                  [3, 88, 24, False, 'ReLU', 1],
+                  [5, 96, 40, True, 'HSwish', 2],  # block2 layer4 os=16
+                  [5, 240, 40, True, 'HSwish', 1],
+                  [5, 240, 40, True, 'HSwish', 1],
+                  [5, 120, 48, True, 'HSwish', 1],  # block3 layer7 os=16
+                  [5, 144, 48, True, 'HSwish', 1],
+                  [5, 288, 96, True, 'HSwish', 2],  # block4 layer9 os=32
+                  [5, 576, 96, True, 'HSwish', 1],
+                  [5, 576, 96, True, 'HSwish', 1]],
+        'large': [[3, 16, 16, False, 'ReLU', 1],  # block0 layer1 os=2
+                  [3, 64, 24, False, 'ReLU', 2],  # block1 layer2 os=4
+                  [3, 72, 24, False, 'ReLU', 1],
+                  [5, 72, 40, True, 'ReLU', 2],  # block2 layer4 os=8
+                  [5, 120, 40, True, 'ReLU', 1],
+                  [5, 120, 40, True, 'ReLU', 1],
+                  [3, 240, 80, False, 'HSwish', 2],  # block3 layer7 os=16
+                  [3, 200, 80, False, 'HSwish', 1],
+                  [3, 184, 80, False, 'HSwish', 1],
+                  [3, 184, 80, False, 'HSwish', 1],
+                  [3, 480, 112, True, 'HSwish', 1],  # block4 layer11 os=16
+                  [3, 672, 112, True, 'HSwish', 1],
+                  [5, 672, 160, True, 'HSwish', 2],  # block5 layer13 os=32
+                  [5, 960, 160, True, 'HSwish', 1],
+                  [5, 960, 160, True, 'HSwish', 1]]
+    }  # yapf: disable
+
+    def __init__(self,
+                 arch='small',
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 out_indices=(0, 1, 12),
+                 frozen_stages=-1,
+                 reduction_factor=1,
+                 norm_eval=False,
+                 with_cp=False):
+        super(MobileNetV3, self).__init__()
+        assert arch in self.arch_settings
+        assert isinstance(reduction_factor, int) and reduction_factor > 0
+        assert mmcv.is_tuple_of(out_indices, int)
+        for index in out_indices:
+            if index not in range(0, len(self.arch_settings[arch]) + 2):
+                raise ValueError(
+                    'the item in out_indices must be in '
+                    f'range(0, {len(self.arch_settings[arch])+2}). '
+                    f'But received {index}')
+
+        if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2):
+            raise ValueError('frozen_stages must be in range(-1, '
+                             f'{len(self.arch_settings[arch])+2}). '
+                             f'But received {frozen_stages}')
+        self.arch = arch
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.out_indices = out_indices
+        self.frozen_stages = frozen_stages
+        self.reduction_factor = reduction_factor
+        self.norm_eval = norm_eval
+        self.with_cp = with_cp
+        self.layers = self._make_layer()
+
+    def _make_layer(self):
+        layers = []
+
+        # build the first layer (layer0)
+        in_channels = 16
+        layer = ConvModule(
+            in_channels=3,
+            out_channels=in_channels,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            conv_cfg=dict(type='Conv2dAdaptivePadding'),
+            norm_cfg=self.norm_cfg,
+            act_cfg=dict(type='HSwish'))
+        self.add_module('layer0', layer)
+        layers.append('layer0')
+
+        layer_setting = self.arch_settings[self.arch]
+        for i, params in enumerate(layer_setting):
+            (kernel_size, mid_channels, out_channels, with_se, act,
+             stride) = params
+
+            if self.arch == 'large' and i >= 12 or self.arch == 'small' and \
+                    i >= 8:
+                mid_channels = mid_channels // self.reduction_factor
+                out_channels = out_channels // self.reduction_factor
+
+            if with_se:
+                se_cfg = dict(
+                    channels=mid_channels,
+                    ratio=4,
+                    act_cfg=(dict(type='ReLU'),
+                             dict(type='HSigmoid', bias=3.0, divisor=6.0)))
+            else:
+                se_cfg = None
+
+            layer = InvertedResidual(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                mid_channels=mid_channels,
+                kernel_size=kernel_size,
+                stride=stride,
+                se_cfg=se_cfg,
+                with_expand_conv=(in_channels != mid_channels),
+                conv_cfg=self.conv_cfg,
+                norm_cfg=self.norm_cfg,
+                act_cfg=dict(type=act),
+                with_cp=self.with_cp)
+            in_channels = out_channels
+            layer_name = 'layer{}'.format(i + 1)
+            self.add_module(layer_name, layer)
+            layers.append(layer_name)
+
+        # build the last layer
+        # block5 layer12 os=32 for small model
+        # block6 layer16 os=32 for large model
+        layer = ConvModule(
+            in_channels=in_channels,
+            out_channels=576 if self.arch == 'small' else 960,
+            kernel_size=1,
+            stride=1,
+            dilation=4,
+            padding=0,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=dict(type='HSwish'))
+        layer_name = 'layer{}'.format(len(layer_setting) + 1)
+        self.add_module(layer_name, layer)
+        layers.append(layer_name)
+
+        # next, convert backbone MobileNetV3 to a semantic segmentation version
+        if self.arch == 'small':
+            self.layer4.depthwise_conv.conv.stride = (1, 1)
+            self.layer9.depthwise_conv.conv.stride = (1, 1)
+            for i in range(4, len(layers)):
+                layer = getattr(self, layers[i])
+                if isinstance(layer, InvertedResidual):
+                    modified_module = layer.depthwise_conv.conv
+                else:
+                    modified_module = layer.conv
+
+                if i < 9:
+                    modified_module.dilation = (2, 2)
+                    pad = 2
+                else:
+                    modified_module.dilation = (4, 4)
+                    pad = 4
+
+                if not isinstance(modified_module, Conv2dAdaptivePadding):
+                    # Adjust padding
+                    pad *= (modified_module.kernel_size[0] - 1) // 2
+                    modified_module.padding = (pad, pad)
+        else:
+            self.layer7.depthwise_conv.conv.stride = (1, 1)
+            self.layer13.depthwise_conv.conv.stride = (1, 1)
+            for i in range(7, len(layers)):
+                layer = getattr(self, layers[i])
+                if isinstance(layer, InvertedResidual):
+                    modified_module = layer.depthwise_conv.conv
+                else:
+                    modified_module = layer.conv
+
+                if i < 13:
+                    modified_module.dilation = (2, 2)
+                    pad = 2
+                else:
+                    modified_module.dilation = (4, 4)
+                    pad = 4
+
+                if not isinstance(modified_module, Conv2dAdaptivePadding):
+                    # Adjust padding
+                    pad *= (modified_module.kernel_size[0] - 1) // 2
+                    modified_module.padding = (pad, pad)
+
+        return layers
+
+    def init_weights(self, pretrained=None):
+        if isinstance(pretrained, str):
+            logger = logging.getLogger()
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    kaiming_init(m)
+                elif isinstance(m, nn.BatchNorm2d):
+                    constant_init(m, 1)
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    def forward(self, x):
+        outs = []
+        for i, layer_name in enumerate(self.layers):
+            layer = getattr(self, layer_name)
+            x = layer(x)
+            if i in self.out_indices:
+                outs.append(x)
+        return outs
+
+    def _freeze_stages(self):
+        for i in range(self.frozen_stages + 1):
+            layer = getattr(self, f'layer{i}')
+            layer.eval()
+            for param in layer.parameters():
+                param.requires_grad = False
+
+    def train(self, mode=True):
+        super(MobileNetV3, self).train(mode)
+        self._freeze_stages()
+        if mode and self.norm_eval:
+            for m in self.modules():
+                if isinstance(m, _BatchNorm):
+                    m.eval()
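A minimal usage sketch for the 'small' variant (illustrative only, not part of the diff):

import torch
from annotator.uniformer.mmseg.models.backbones.mobilenet_v3 import MobileNetV3

model = MobileNetV3(arch='small', out_indices=(0, 1, 12))
model.init_weights()
model.eval()
with torch.no_grad():
    outs = model(torch.rand(1, 3, 224, 224))
# layer4/layer9 are reset to stride 1 in _make_layer, so the deepest feature
# (index 12, 576 channels) stays at output stride 8.
for out in outs:
    print(tuple(out.shape))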
diff --git a/annotator/uniformer/mmseg/models/backbones/resnest.py b/annotator/uniformer/mmseg/models/backbones/resnest.py
new file mode 100644
index 0000000000000000000000000000000000000000..b45a837f395230029e9d4194ff9f7f2f8f7067b0
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/backbones/resnest.py
@@ -0,0 +1,314 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from annotator.uniformer.mmcv.cnn import build_conv_layer, build_norm_layer
+
+from ..builder import BACKBONES
+from ..utils import ResLayer
+from .resnet import Bottleneck as _Bottleneck
+from .resnet import ResNetV1d
+
+
+class RSoftmax(nn.Module):
+    """Radix Softmax module in ``SplitAttentionConv2d``.
+
+    Args:
+        radix (int): Radix of input.
+        groups (int): Groups of input.
+    """
+
+    def __init__(self, radix, groups):
+        super().__init__()
+        self.radix = radix
+        self.groups = groups
+
+    def forward(self, x):
+        batch = x.size(0)
+        if self.radix > 1:
+            x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
+            x = F.softmax(x, dim=1)
+            x = x.reshape(batch, -1)
+        else:
+            x = torch.sigmoid(x)
+        return x
+
+
+class SplitAttentionConv2d(nn.Module):
+    """Split-Attention Conv2d in ResNeSt.
+
+    Args:
+        in_channels (int): Same as nn.Conv2d.
+        channels (int): Same as out_channels of nn.Conv2d.
+        kernel_size (int | tuple[int]): Same as nn.Conv2d.
+        stride (int | tuple[int]): Same as nn.Conv2d.
+        padding (int | tuple[int]): Same as nn.Conv2d.
+        dilation (int | tuple[int]): Same as nn.Conv2d.
+        groups (int): Same as nn.Conv2d.
+        radix (int): Radix of SplitAttentionConv2d. Default: 2.
+        reduction_factor (int): Reduction factor of inter_channels. Default: 4.
+        conv_cfg (dict): Config dict for convolution layer. Default: None,
+            which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN').
+        dcn (dict): Config dict for DCN. Default: None.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=1,
+                 radix=2,
+                 reduction_factor=4,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 dcn=None):
+        super(SplitAttentionConv2d, self).__init__()
+        inter_channels = max(in_channels * radix // reduction_factor, 32)
+        self.radix = radix
+        self.groups = groups
+        self.channels = channels
+        self.with_dcn = dcn is not None
+        self.dcn = dcn
+        fallback_on_stride = False
+        if self.with_dcn:
+            fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
+        if self.with_dcn and not fallback_on_stride:
+            assert conv_cfg is None, 'conv_cfg must be None for DCN'
+            conv_cfg = dcn
+        self.conv = build_conv_layer(
+            conv_cfg,
+            in_channels,
+            channels * radix,
+            kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups * radix,
+            bias=False)
+        self.norm0_name, norm0 = build_norm_layer(
+            norm_cfg, channels * radix, postfix=0)
+        self.add_module(self.norm0_name, norm0)
+        self.relu = nn.ReLU(inplace=True)
+        self.fc1 = build_conv_layer(
+            None, channels, inter_channels, 1, groups=self.groups)
+        self.norm1_name, norm1 = build_norm_layer(
+            norm_cfg, inter_channels, postfix=1)
+        self.add_module(self.norm1_name, norm1)
+        self.fc2 = build_conv_layer(
+            None, inter_channels, channels * radix, 1, groups=self.groups)
+        self.rsoftmax = RSoftmax(radix, groups)
+
+    @property
+    def norm0(self):
+        """nn.Module: the normalization layer named "norm0" """
+        return getattr(self, self.norm0_name)
+
+    @property
+    def norm1(self):
+        """nn.Module: the normalization layer named "norm1" """
+        return getattr(self, self.norm1_name)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.norm0(x)
+        x = self.relu(x)
+
+        batch = x.size(0)
+        if self.radix > 1:
+            splits = x.view(batch, self.radix, -1, *x.shape[2:])
+            gap = splits.sum(dim=1)
+        else:
+            gap = x
+        gap = F.adaptive_avg_pool2d(gap, 1)
+        gap = self.fc1(gap)
+
+        gap = self.norm1(gap)
+        gap = self.relu(gap)
+
+        atten = self.fc2(gap)
+        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
+
+        if self.radix > 1:
+            attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
+            out = torch.sum(attens * splits, dim=1)
+        else:
+            out = atten * x
+        return out.contiguous()
+
+
+class Bottleneck(_Bottleneck):
+    """Bottleneck block for ResNeSt.
+
+    Args:
+        inplanes (int): Input planes of this block.
+        planes (int): Middle planes of this block.
+        groups (int): Groups of conv2.
+        base_width (int): Base width per group of conv2. 64x4d indicates
+            ``groups=64, base_width=4`` and 32x8d indicates
+            ``groups=32, base_width=8``.
+        radix (int): Radix of SplitAttentionConv2d. Default: 2.
+        reduction_factor (int): Reduction factor of inter_channels in
+            SplitAttentionConv2d. Default: 4.
+        avg_down_stride (bool): Whether to use average pool for stride in
+            Bottleneck. Default: True.
+        kwargs (dict): Key word arguments for base class.
+    """
+    expansion = 4
+
+    def __init__(self,
+                 inplanes,
+                 planes,
+                 groups=1,
+                 base_width=4,
+                 base_channels=64,
+                 radix=2,
+                 reduction_factor=4,
+                 avg_down_stride=True,
+                 **kwargs):
+        """Bottleneck block for ResNeSt."""
+        super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
+
+        if groups == 1:
+            width = self.planes
+        else:
+            width = math.floor(self.planes *
+                               (base_width / base_channels)) * groups
+
+        self.avg_down_stride = avg_down_stride and self.conv2_stride > 1
+
+        self.norm1_name, norm1 = build_norm_layer(
+            self.norm_cfg, width, postfix=1)
+        self.norm3_name, norm3 = build_norm_layer(
+            self.norm_cfg, self.planes * self.expansion, postfix=3)
+
+        self.conv1 = build_conv_layer(
+            self.conv_cfg,
+            self.inplanes,
+            width,
+            kernel_size=1,
+            stride=self.conv1_stride,
+            bias=False)
+        self.add_module(self.norm1_name, norm1)
+        self.with_modulated_dcn = False
+        self.conv2 = SplitAttentionConv2d(
+            width,
+            width,
+            kernel_size=3,
+            stride=1 if self.avg_down_stride else self.conv2_stride,
+            padding=self.dilation,
+            dilation=self.dilation,
+            groups=groups,
+            radix=radix,
+            reduction_factor=reduction_factor,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            dcn=self.dcn)
+        delattr(self, self.norm2_name)
+
+        if self.avg_down_stride:
+            self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)
+
+        self.conv3 = build_conv_layer(
+            self.conv_cfg,
+            width,
+            self.planes * self.expansion,
+            kernel_size=1,
+            bias=False)
+        self.add_module(self.norm3_name, norm3)
+
+    def forward(self, x):
+
+        def _inner_forward(x):
+            identity = x
+
+            out = self.conv1(x)
+            out = self.norm1(out)
+            out = self.relu(out)
+
+            if self.with_plugins:
+                out = self.forward_plugin(out, self.after_conv1_plugin_names)
+
+            out = self.conv2(out)
+
+            if self.avg_down_stride:
+                out = self.avd_layer(out)
+
+            if self.with_plugins:
+                out = self.forward_plugin(out, self.after_conv2_plugin_names)
+
+            out = self.conv3(out)
+            out = self.norm3(out)
+
+            if self.with_plugins:
+                out = self.forward_plugin(out, self.after_conv3_plugin_names)
+
+            if self.downsample is not None:
+                identity = self.downsample(x)
+
+            out += identity
+
+            return out
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(_inner_forward, x)
+        else:
+            out = _inner_forward(x)
+
+        out = self.relu(out)
+
+        return out
+
+
+@BACKBONES.register_module()
+class ResNeSt(ResNetV1d):
+    """ResNeSt backbone.
+
+    Args:
+        groups (int): Number of groups of Bottleneck. Default: 1
+        base_width (int): Base width of Bottleneck. Default: 4
+        radix (int): Radix of SplitAttentionConv2d. Default: 2.
+        reduction_factor (int): Reduction factor of inter_channels in
+            SplitAttentionConv2d. Default: 4.
+        avg_down_stride (bool): Whether to use average pool for stride in
+            Bottleneck. Default: True.
+        kwargs (dict): Keyword arguments for ResNet.
+    """
+
+    arch_settings = {
+        50: (Bottleneck, (3, 4, 6, 3)),
+        101: (Bottleneck, (3, 4, 23, 3)),
+        152: (Bottleneck, (3, 8, 36, 3)),
+        200: (Bottleneck, (3, 24, 36, 3))
+    }
+
+    def __init__(self,
+                 groups=1,
+                 base_width=4,
+                 radix=2,
+                 reduction_factor=4,
+                 avg_down_stride=True,
+                 **kwargs):
+        self.groups = groups
+        self.base_width = base_width
+        self.radix = radix
+        self.reduction_factor = reduction_factor
+        self.avg_down_stride = avg_down_stride
+        super(ResNeSt, self).__init__(**kwargs)
+
+    def make_res_layer(self, **kwargs):
+        """Pack all blocks in a stage into a ``ResLayer``."""
+        return ResLayer(
+            groups=self.groups,
+            base_width=self.base_width,
+            base_channels=self.base_channels,
+            radix=self.radix,
+            reduction_factor=self.reduction_factor,
+            avg_down_stride=self.avg_down_stride,
+            **kwargs)
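A minimal sketch of the split-attention block on its own (illustrative only, not part of the diff): the grouped conv produces ``channels * radix`` features, which are split, attended over the radix dimension and fused back to ``channels`` outputs.

import torch
from annotator.uniformer.mmseg.models.backbones.resnest import SplitAttentionConv2d

conv = SplitAttentionConv2d(
    in_channels=64, channels=64, kernel_size=3, padding=1, radix=2, groups=1)
conv.eval()
with torch.no_grad():
    out = conv(torch.rand(2, 64, 32, 32))
print(tuple(out.shape))  # (2, 64, 32, 32): the radix splits are fused back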
diff --git a/annotator/uniformer/mmseg/models/backbones/resnet.py b/annotator/uniformer/mmseg/models/backbones/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e52bf048d28ecb069db4728e5f05ad85ac53198
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/backbones/resnet.py
@@ -0,0 +1,688 @@
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from annotator.uniformer.mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer,
+                      constant_init, kaiming_init)
+from annotator.uniformer.mmcv.runner import load_checkpoint
+from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm
+
+from annotator.uniformer.mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+from ..utils import ResLayer
+
+
+class BasicBlock(nn.Module):
+    """Basic block for ResNet."""
+
+    expansion = 1
+
+    def __init__(self,
+                 inplanes,
+                 planes,
+                 stride=1,
+                 dilation=1,
+                 downsample=None,
+                 style='pytorch',
+                 with_cp=False,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 dcn=None,
+                 plugins=None):
+        super(BasicBlock, self).__init__()
+        assert dcn is None, 'Not implemented yet.'
+        assert plugins is None, 'Not implemented yet.'
+
+        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
+        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
+
+        self.conv1 = build_conv_layer(
+            conv_cfg,
+            inplanes,
+            planes,
+            3,
+            stride=stride,
+            padding=dilation,
+            dilation=dilation,
+            bias=False)
+        self.add_module(self.norm1_name, norm1)
+        self.conv2 = build_conv_layer(
+            conv_cfg, planes, planes, 3, padding=1, bias=False)
+        self.add_module(self.norm2_name, norm2)
+
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+        self.dilation = dilation
+        self.with_cp = with_cp
+
+    @property
+    def norm1(self):
+        """nn.Module: normalization layer after the first convolution layer"""
+        return getattr(self, self.norm1_name)
+
+    @property
+    def norm2(self):
+        """nn.Module: normalization layer after the second convolution layer"""
+        return getattr(self, self.norm2_name)
+
+    def forward(self, x):
+        """Forward function."""
+
+        def _inner_forward(x):
+            identity = x
+
+            out = self.conv1(x)
+            out = self.norm1(out)
+            out = self.relu(out)
+
+            out = self.conv2(out)
+            out = self.norm2(out)
+
+            if self.downsample is not None:
+                identity = self.downsample(x)
+
+            out += identity
+
+            return out
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(_inner_forward, x)
+        else:
+            out = _inner_forward(x)
+
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    """Bottleneck block for ResNet.
+
+    If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
+    "caffe", the stride-two layer is the first 1x1 conv layer.
+    """
+
+    expansion = 4
+
+    def __init__(self,
+                 inplanes,
+                 planes,
+                 stride=1,
+                 dilation=1,
+                 downsample=None,
+                 style='pytorch',
+                 with_cp=False,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 dcn=None,
+                 plugins=None):
+        super(Bottleneck, self).__init__()
+        assert style in ['pytorch', 'caffe']
+        assert dcn is None or isinstance(dcn, dict)
+        assert plugins is None or isinstance(plugins, list)
+        if plugins is not None:
+            allowed_position = ['after_conv1', 'after_conv2', 'after_conv3']
+            assert all(p['position'] in allowed_position for p in plugins)
+
+        self.inplanes = inplanes
+        self.planes = planes
+        self.stride = stride
+        self.dilation = dilation
+        self.style = style
+        self.with_cp = with_cp
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.dcn = dcn
+        self.with_dcn = dcn is not None
+        self.plugins = plugins
+        self.with_plugins = plugins is not None
+
+        if self.with_plugins:
+            # collect plugins for conv1/conv2/conv3
+            self.after_conv1_plugins = [
+                plugin['cfg'] for plugin in plugins
+                if plugin['position'] == 'after_conv1'
+            ]
+            self.after_conv2_plugins = [
+                plugin['cfg'] for plugin in plugins
+                if plugin['position'] == 'after_conv2'
+            ]
+            self.after_conv3_plugins = [
+                plugin['cfg'] for plugin in plugins
+                if plugin['position'] == 'after_conv3'
+            ]
+
+        if self.style == 'pytorch':
+            self.conv1_stride = 1
+            self.conv2_stride = stride
+        else:
+            self.conv1_stride = stride
+            self.conv2_stride = 1
+
+        self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
+        self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
+        self.norm3_name, norm3 = build_norm_layer(
+            norm_cfg, planes * self.expansion, postfix=3)
+
+        self.conv1 = build_conv_layer(
+            conv_cfg,
+            inplanes,
+            planes,
+            kernel_size=1,
+            stride=self.conv1_stride,
+            bias=False)
+        self.add_module(self.norm1_name, norm1)
+        fallback_on_stride = False
+        if self.with_dcn:
+            fallback_on_stride = dcn.pop('fallback_on_stride', False)
+        if not self.with_dcn or fallback_on_stride:
+            self.conv2 = build_conv_layer(
+                conv_cfg,
+                planes,
+                planes,
+                kernel_size=3,
+                stride=self.conv2_stride,
+                padding=dilation,
+                dilation=dilation,
+                bias=False)
+        else:
+            assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
+            self.conv2 = build_conv_layer(
+                dcn,
+                planes,
+                planes,
+                kernel_size=3,
+                stride=self.conv2_stride,
+                padding=dilation,
+                dilation=dilation,
+                bias=False)
+
+        self.add_module(self.norm2_name, norm2)
+        self.conv3 = build_conv_layer(
+            conv_cfg,
+            planes,
+            planes * self.expansion,
+            kernel_size=1,
+            bias=False)
+        self.add_module(self.norm3_name, norm3)
+
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+
+        if self.with_plugins:
+            self.after_conv1_plugin_names = self.make_block_plugins(
+                planes, self.after_conv1_plugins)
+            self.after_conv2_plugin_names = self.make_block_plugins(
+                planes, self.after_conv2_plugins)
+            self.after_conv3_plugin_names = self.make_block_plugins(
+                planes * self.expansion, self.after_conv3_plugins)
+
+    def make_block_plugins(self, in_channels, plugins):
+        """make plugins for block.
+
+        Args:
+            in_channels (int): Input channels of plugin.
+            plugins (list[dict]): List of plugins cfg to build.
+
+        Returns:
+            list[str]: List of the names of plugin.
+        """
+        assert isinstance(plugins, list)
+        plugin_names = []
+        for plugin in plugins:
+            plugin = plugin.copy()
+            name, layer = build_plugin_layer(
+                plugin,
+                in_channels=in_channels,
+                postfix=plugin.pop('postfix', ''))
+            assert not hasattr(self, name), f'duplicate plugin {name}'
+            self.add_module(name, layer)
+            plugin_names.append(name)
+        return plugin_names
+
+    def forward_plugin(self, x, plugin_names):
+        """Forward function for plugins."""
+        out = x
+        for name in plugin_names:
+            out = getattr(self, name)(out)
+        return out
+
+    @property
+    def norm1(self):
+        """nn.Module: normalization layer after the first convolution layer"""
+        return getattr(self, self.norm1_name)
+
+    @property
+    def norm2(self):
+        """nn.Module: normalization layer after the second convolution layer"""
+        return getattr(self, self.norm2_name)
+
+    @property
+    def norm3(self):
+        """nn.Module: normalization layer after the third convolution layer"""
+        return getattr(self, self.norm3_name)
+
+    def forward(self, x):
+        """Forward function."""
+
+        def _inner_forward(x):
+            identity = x
+
+            out = self.conv1(x)
+            out = self.norm1(out)
+            out = self.relu(out)
+
+            if self.with_plugins:
+                out = self.forward_plugin(out, self.after_conv1_plugin_names)
+
+            out = self.conv2(out)
+            out = self.norm2(out)
+            out = self.relu(out)
+
+            if self.with_plugins:
+                out = self.forward_plugin(out, self.after_conv2_plugin_names)
+
+            out = self.conv3(out)
+            out = self.norm3(out)
+
+            if self.with_plugins:
+                out = self.forward_plugin(out, self.after_conv3_plugin_names)
+
+            if self.downsample is not None:
+                identity = self.downsample(x)
+
+            out += identity
+
+            return out
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(_inner_forward, x)
+        else:
+            out = _inner_forward(x)
+
+        out = self.relu(out)
+
+        return out
+
+
+@BACKBONES.register_module()
+class ResNet(nn.Module):
+    """ResNet backbone.
+
+    Args:
+        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
+        in_channels (int): Number of input image channels. Default" 3.
+        stem_channels (int): Number of stem channels. Default: 64.
+        base_channels (int): Number of base channels of res layer. Default: 64.
+        num_stages (int): Resnet stages, normally 4.
+        strides (Sequence[int]): Strides of the first block of each stage.
+        dilations (Sequence[int]): Dilation of each stage.
+        out_indices (Sequence[int]): Output from which stages.
+        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+            layer is the 3x3 conv layer, otherwise the stride-two layer is
+            the first 1x1 conv layer.
+        deep_stem (bool): Replace the 7x7 conv in the input stem with three
+            3x3 convs.
+        avg_down (bool): Use AvgPool instead of stride conv when
+            downsampling in the bottleneck.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters.
+        norm_cfg (dict): Dictionary to construct and config norm layer.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only.
+        plugins (list[dict]): List of plugins for stages, each dict contains:
+
+            - cfg (dict, required): Cfg dict to build plugin.
+
+            - position (str, required): Position inside block to insert plugin,
+            options: 'after_conv1', 'after_conv2', 'after_conv3'.
+
+            - stages (tuple[bool], optional): Stages to apply plugin, length
+            should be same as 'num_stages'
+        multi_grid (Sequence[int]|None): Multi grid dilation rates of the last
+            stage. Default: None.
+        contract_dilation (bool): Whether to contract the first dilation of
+            each layer. Default: False.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed.
+        zero_init_residual (bool): Whether to use zero init for last norm layer
+            in resblocks to let them behave as identity.
+
+    Example:
+        >>> from annotator.uniformer.mmseg.models import ResNet
+        >>> import torch
+        >>> self = ResNet(depth=18)
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 32, 32)
+        >>> level_outputs = self.forward(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        (1, 64, 8, 8)
+        (1, 128, 4, 4)
+        (1, 256, 2, 2)
+        (1, 512, 1, 1)
+    """
+
+    arch_settings = {
+        18: (BasicBlock, (2, 2, 2, 2)),
+        34: (BasicBlock, (3, 4, 6, 3)),
+        50: (Bottleneck, (3, 4, 6, 3)),
+        101: (Bottleneck, (3, 4, 23, 3)),
+        152: (Bottleneck, (3, 8, 36, 3))
+    }
+
+    def __init__(self,
+                 depth,
+                 in_channels=3,
+                 stem_channels=64,
+                 base_channels=64,
+                 num_stages=4,
+                 strides=(1, 2, 2, 2),
+                 dilations=(1, 1, 1, 1),
+                 out_indices=(0, 1, 2, 3),
+                 style='pytorch',
+                 deep_stem=False,
+                 avg_down=False,
+                 frozen_stages=-1,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN', requires_grad=True),
+                 norm_eval=False,
+                 dcn=None,
+                 stage_with_dcn=(False, False, False, False),
+                 plugins=None,
+                 multi_grid=None,
+                 contract_dilation=False,
+                 with_cp=False,
+                 zero_init_residual=True):
+        super(ResNet, self).__init__()
+        if depth not in self.arch_settings:
+            raise KeyError(f'invalid depth {depth} for resnet')
+        self.depth = depth
+        self.stem_channels = stem_channels
+        self.base_channels = base_channels
+        self.num_stages = num_stages
+        assert num_stages >= 1 and num_stages <= 4
+        self.strides = strides
+        self.dilations = dilations
+        assert len(strides) == len(dilations) == num_stages
+        self.out_indices = out_indices
+        assert max(out_indices) < num_stages
+        self.style = style
+        self.deep_stem = deep_stem
+        self.avg_down = avg_down
+        self.frozen_stages = frozen_stages
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.with_cp = with_cp
+        self.norm_eval = norm_eval
+        self.dcn = dcn
+        self.stage_with_dcn = stage_with_dcn
+        if dcn is not None:
+            assert len(stage_with_dcn) == num_stages
+        self.plugins = plugins
+        self.multi_grid = multi_grid
+        self.contract_dilation = contract_dilation
+        self.zero_init_residual = zero_init_residual
+        self.block, stage_blocks = self.arch_settings[depth]
+        self.stage_blocks = stage_blocks[:num_stages]
+        self.inplanes = stem_channels
+
+        self._make_stem_layer(in_channels, stem_channels)
+
+        self.res_layers = []
+        for i, num_blocks in enumerate(self.stage_blocks):
+            stride = strides[i]
+            dilation = dilations[i]
+            dcn = self.dcn if self.stage_with_dcn[i] else None
+            if plugins is not None:
+                stage_plugins = self.make_stage_plugins(plugins, i)
+            else:
+                stage_plugins = None
+            # multi grid is applied to last layer only
+            stage_multi_grid = multi_grid if i == len(
+                self.stage_blocks) - 1 else None
+            planes = base_channels * 2**i
+            res_layer = self.make_res_layer(
+                block=self.block,
+                inplanes=self.inplanes,
+                planes=planes,
+                num_blocks=num_blocks,
+                stride=stride,
+                dilation=dilation,
+                style=self.style,
+                avg_down=self.avg_down,
+                with_cp=with_cp,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                dcn=dcn,
+                plugins=stage_plugins,
+                multi_grid=stage_multi_grid,
+                contract_dilation=contract_dilation)
+            self.inplanes = planes * self.block.expansion
+            layer_name = f'layer{i+1}'
+            self.add_module(layer_name, res_layer)
+            self.res_layers.append(layer_name)
+
+        self._freeze_stages()
+
+        self.feat_dim = self.block.expansion * base_channels * 2**(
+            len(self.stage_blocks) - 1)
+
+    def make_stage_plugins(self, plugins, stage_idx):
+        """make plugins for ResNet 'stage_idx'th stage .
+
+        Currently we support to insert 'context_block',
+        'empirical_attention_block', 'nonlocal_block' into the backbone like
+        ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of
+        Bottleneck.
+
+        An example of the plugins format could be:
+        >>> plugins=[
+        ...     dict(cfg=dict(type='xxx', arg1='xxx'),
+        ...          stages=(False, True, True, True),
+        ...          position='after_conv2'),
+        ...     dict(cfg=dict(type='yyy'),
+        ...          stages=(True, True, True, True),
+        ...          position='after_conv3'),
+        ...     dict(cfg=dict(type='zzz', postfix='1'),
+        ...          stages=(True, True, True, True),
+        ...          position='after_conv3'),
+        ...     dict(cfg=dict(type='zzz', postfix='2'),
+        ...          stages=(True, True, True, True),
+        ...          position='after_conv3')
+        ... ]
+        >>> self = ResNet(depth=18)
+        >>> stage_plugins = self.make_stage_plugins(plugins, 0)
+        >>> assert len(stage_plugins) == 3
+
+        Suppose ``stage_idx=0``, the structure of blocks in the stage would be:
+            conv1 -> conv2 -> conv3 -> yyy -> zzz1 -> zzz2
+        Suppose ``stage_idx=1``, the structure of blocks in the stage would be:
+            conv1 -> conv2 -> xxx -> conv3 -> yyy -> zzz1 -> zzz2
+
+        If stages is missing, the plugin would be applied to all stages.
+
+        Args:
+            plugins (list[dict]): List of plugins cfg to build. The postfix is
+                required if multiple same type plugins are inserted.
+            stage_idx (int): Index of stage to build
+
+        Returns:
+            list[dict]: Plugins for current stage
+        """
+        stage_plugins = []
+        for plugin in plugins:
+            plugin = plugin.copy()
+            stages = plugin.pop('stages', None)
+            assert stages is None or len(stages) == self.num_stages
+            # whether to insert plugin into current stage
+            if stages is None or stages[stage_idx]:
+                stage_plugins.append(plugin)
+
+        return stage_plugins
+
+    def make_res_layer(self, **kwargs):
+        """Pack all blocks in a stage into a ``ResLayer``."""
+        return ResLayer(**kwargs)
+
+    @property
+    def norm1(self):
+        """nn.Module: the normalization layer named "norm1" """
+        return getattr(self, self.norm1_name)
+
+    def _make_stem_layer(self, in_channels, stem_channels):
+        """Make stem layer for ResNet."""
+        if self.deep_stem:
+            self.stem = nn.Sequential(
+                build_conv_layer(
+                    self.conv_cfg,
+                    in_channels,
+                    stem_channels // 2,
+                    kernel_size=3,
+                    stride=2,
+                    padding=1,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
+                nn.ReLU(inplace=True),
+                build_conv_layer(
+                    self.conv_cfg,
+                    stem_channels // 2,
+                    stem_channels // 2,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
+                nn.ReLU(inplace=True),
+                build_conv_layer(
+                    self.conv_cfg,
+                    stem_channels // 2,
+                    stem_channels,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, stem_channels)[1],
+                nn.ReLU(inplace=True))
+        else:
+            self.conv1 = build_conv_layer(
+                self.conv_cfg,
+                in_channels,
+                stem_channels,
+                kernel_size=7,
+                stride=2,
+                padding=3,
+                bias=False)
+            self.norm1_name, norm1 = build_norm_layer(
+                self.norm_cfg, stem_channels, postfix=1)
+            self.add_module(self.norm1_name, norm1)
+            self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+    def _freeze_stages(self):
+        """Freeze stages param and norm stats."""
+        if self.frozen_stages >= 0:
+            if self.deep_stem:
+                self.stem.eval()
+                for param in self.stem.parameters():
+                    param.requires_grad = False
+            else:
+                self.norm1.eval()
+                for m in [self.conv1, self.norm1]:
+                    for param in m.parameters():
+                        param.requires_grad = False
+
+        for i in range(1, self.frozen_stages + 1):
+            m = getattr(self, f'layer{i}')
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+
+    def init_weights(self, pretrained=None):
+        """Initialize the weights in backbone.
+
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+        """
+        if isinstance(pretrained, str):
+            logger = get_root_logger()
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    kaiming_init(m)
+                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                    constant_init(m, 1)
+
+            if self.dcn is not None:
+                for m in self.modules():
+                    if isinstance(m, Bottleneck) and hasattr(
+                            m, 'conv2_offset'):
+                        constant_init(m.conv2_offset, 0)
+
+            if self.zero_init_residual:
+                for m in self.modules():
+                    if isinstance(m, Bottleneck):
+                        constant_init(m.norm3, 0)
+                    elif isinstance(m, BasicBlock):
+                        constant_init(m.norm2, 0)
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    def forward(self, x):
+        """Forward function."""
+        if self.deep_stem:
+            x = self.stem(x)
+        else:
+            x = self.conv1(x)
+            x = self.norm1(x)
+            x = self.relu(x)
+        x = self.maxpool(x)
+        outs = []
+        for i, layer_name in enumerate(self.res_layers):
+            res_layer = getattr(self, layer_name)
+            x = res_layer(x)
+            if i in self.out_indices:
+                outs.append(x)
+        return tuple(outs)
+
+    def train(self, mode=True):
+        """Convert the model into training mode while keep normalization layer
+        freezed."""
+        super(ResNet, self).train(mode)
+        self._freeze_stages()
+        if mode and self.norm_eval:
+            for m in self.modules():
+                # trick: eval() only has an effect on BatchNorm layers
+                if isinstance(m, _BatchNorm):
+                    m.eval()
+
+
+@BACKBONES.register_module()
+class ResNetV1c(ResNet):
+    """ResNetV1c variant described in [1]_.
+
+    Compared with the default ResNet (ResNetV1b), ResNetV1c replaces the 7x7
+    conv in the input stem with three 3x3 convs.
+
+    References:
+        .. [1] https://arxiv.org/pdf/1812.01187.pdf
+    """
+
+    def __init__(self, **kwargs):
+        super(ResNetV1c, self).__init__(
+            deep_stem=True, avg_down=False, **kwargs)
+
+
+@BACKBONES.register_module()
+class ResNetV1d(ResNet):
+    """ResNetV1d variant described in [1]_.
+
+    Compared with the default ResNet (ResNetV1b), ResNetV1d replaces the 7x7
+    conv in the input stem with three 3x3 convs, and in the downsampling block
+    a 2x2 avg_pool with stride 2 is added before the conv, whose stride is
+    changed to 1.
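+
+    Example:
+        >>> # Illustrative example; the expected shapes assume the default
+        >>> # out_indices=(0, 1, 2, 3) and the usual 4x stem reduction
+        >>> # followed by one 2x reduction per later stage.
+        >>> import torch
+        >>> self = ResNetV1d(depth=50)
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 64, 64)
+        >>> level_outputs = self.forward(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        (1, 256, 16, 16)
+        (1, 512, 8, 8)
+        (1, 1024, 4, 4)
+        (1, 2048, 2, 2)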
+    """
+
+    def __init__(self, **kwargs):
+        super(ResNetV1d, self).__init__(
+            deep_stem=True, avg_down=True, **kwargs)
diff --git a/annotator/uniformer/mmseg/models/backbones/resnext.py b/annotator/uniformer/mmseg/models/backbones/resnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..962249ad6fd9b50960ad6426f7ce3cac6ed8c5bc
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/backbones/resnext.py
@@ -0,0 +1,145 @@
+import math
+
+from annotator.uniformer.mmcv.cnn import build_conv_layer, build_norm_layer
+
+from ..builder import BACKBONES
+from ..utils import ResLayer
+from .resnet import Bottleneck as _Bottleneck
+from .resnet import ResNet
+
+
+class Bottleneck(_Bottleneck):
+    """Bottleneck block for ResNeXt.
+
+    If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
+    "caffe", the stride-two layer is the first 1x1 conv layer.
+    """
+
+    def __init__(self,
+                 inplanes,
+                 planes,
+                 groups=1,
+                 base_width=4,
+                 base_channels=64,
+                 **kwargs):
+        super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
+
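+        # ResNeXt widens the grouped 3x3 conv:
+        # width = floor(planes * base_width / base_channels) * groups,
+        # e.g. planes=64, base_width=4, base_channels=64, groups=32 -> 128
+        # (the standard "32x4d" setting).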
+        if groups == 1:
+            width = self.planes
+        else:
+            width = math.floor(self.planes *
+                               (base_width / base_channels)) * groups
+
+        self.norm1_name, norm1 = build_norm_layer(
+            self.norm_cfg, width, postfix=1)
+        self.norm2_name, norm2 = build_norm_layer(
+            self.norm_cfg, width, postfix=2)
+        self.norm3_name, norm3 = build_norm_layer(
+            self.norm_cfg, self.planes * self.expansion, postfix=3)
+
+        self.conv1 = build_conv_layer(
+            self.conv_cfg,
+            self.inplanes,
+            width,
+            kernel_size=1,
+            stride=self.conv1_stride,
+            bias=False)
+        self.add_module(self.norm1_name, norm1)
+        fallback_on_stride = False
+        self.with_modulated_dcn = False
+        if self.with_dcn:
+            fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
+        if not self.with_dcn or fallback_on_stride:
+            self.conv2 = build_conv_layer(
+                self.conv_cfg,
+                width,
+                width,
+                kernel_size=3,
+                stride=self.conv2_stride,
+                padding=self.dilation,
+                dilation=self.dilation,
+                groups=groups,
+                bias=False)
+        else:
+            assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
+            self.conv2 = build_conv_layer(
+                self.dcn,
+                width,
+                width,
+                kernel_size=3,
+                stride=self.conv2_stride,
+                padding=self.dilation,
+                dilation=self.dilation,
+                groups=groups,
+                bias=False)
+
+        self.add_module(self.norm2_name, norm2)
+        self.conv3 = build_conv_layer(
+            self.conv_cfg,
+            width,
+            self.planes * self.expansion,
+            kernel_size=1,
+            bias=False)
+        self.add_module(self.norm3_name, norm3)
+
+
+@BACKBONES.register_module()
+class ResNeXt(ResNet):
+    """ResNeXt backbone.
+
+    Args:
+        depth (int): Depth of ResNeXt, from {50, 101, 152}.
+        in_channels (int): Number of input image channels. Normally 3.
+        num_stages (int): Resnet stages, normally 4.
+        groups (int): Group of resnext.
+        base_width (int): Base width of resnext.
+        strides (Sequence[int]): Strides of the first block of each stage.
+        dilations (Sequence[int]): Dilation of each stage.
+        out_indices (Sequence[int]): Output from which stages.
+        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+            layer is the 3x3 conv layer, otherwise the stride-two layer is
+            the first 1x1 conv layer.
+        frozen_stages (int): Stages to be frozen (all parameters fixed).
+            -1 means not freezing any parameters.
+        norm_cfg (dict): dictionary to construct and config norm layer.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed.
+        zero_init_residual (bool): whether to use zero init for last norm layer
+            in resblocks to let them behave as identity.
+
+    Example:
+        >>> from annotator.uniformer.mmseg.models import ResNeXt
+        >>> import torch
+        >>> self = ResNeXt(depth=50)
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 32, 32)
+        >>> level_outputs = self.forward(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        (1, 256, 8, 8)
+        (1, 512, 4, 4)
+        (1, 1024, 2, 2)
+        (1, 2048, 1, 1)
+    """
+
+    arch_settings = {
+        50: (Bottleneck, (3, 4, 6, 3)),
+        101: (Bottleneck, (3, 4, 23, 3)),
+        152: (Bottleneck, (3, 8, 36, 3))
+    }
+
+    def __init__(self, groups=1, base_width=4, **kwargs):
+        self.groups = groups
+        self.base_width = base_width
+        super(ResNeXt, self).__init__(**kwargs)
+
+    def make_res_layer(self, **kwargs):
+        """Pack all blocks in a stage into a ``ResLayer``"""
+        return ResLayer(
+            groups=self.groups,
+            base_width=self.base_width,
+            base_channels=self.base_channels,
+            **kwargs)
diff --git a/annotator/uniformer/mmseg/models/backbones/unet.py b/annotator/uniformer/mmseg/models/backbones/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..82caa16a94c195c192a2a920fb7bc7e60f0f3ce3
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/backbones/unet.py
@@ -0,0 +1,429 @@
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from annotator.uniformer.mmcv.cnn import (UPSAMPLE_LAYERS, ConvModule, build_activation_layer,
+                      build_norm_layer, constant_init, kaiming_init)
+from annotator.uniformer.mmcv.runner import load_checkpoint
+from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm
+
+from annotator.uniformer.mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+from ..utils import UpConvBlock
+
+
+class BasicConvBlock(nn.Module):
+    """Basic convolutional block for UNet.
+
+    This module consists of several plain convolutional layers.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        num_convs (int): Number of convolutional layers. Default: 2.
+        stride (int): Stride of the first convolutional layer. If stride=2,
+            only the first convolutional layer uses a strided convolution to
+            downsample the input feature map. Options are 1 or 2. Default: 1.
+        dilation (int): Dilation rate used to expand the receptive field.
+            It is applied to every convolutional layer except the first one,
+            whose dilation rate is always 1. Default: 1.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+        conv_cfg (dict | None): Config dict for convolution layer.
+            Default: None.
+        norm_cfg (dict | None): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict | None): Config dict for activation layer in ConvModule.
+            Default: dict(type='ReLU').
+        dcn (dict | None): Config dict for deformable convolution.
+            Not implemented yet. Default: None.
+        plugins (dict | None): Plugins for convolutional layers.
+            Not implemented yet. Default: None.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 num_convs=2,
+                 stride=1,
+                 dilation=1,
+                 with_cp=False,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 dcn=None,
+                 plugins=None):
+        super(BasicConvBlock, self).__init__()
+        assert dcn is None, 'Not implemented yet.'
+        assert plugins is None, 'Not implemented yet.'
+
+        self.with_cp = with_cp
+        convs = []
+        for i in range(num_convs):
+            convs.append(
+                ConvModule(
+                    in_channels=in_channels if i == 0 else out_channels,
+                    out_channels=out_channels,
+                    kernel_size=3,
+                    stride=stride if i == 0 else 1,
+                    dilation=1 if i == 0 else dilation,
+                    padding=1 if i == 0 else dilation,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg))
+
+        self.convs = nn.Sequential(*convs)
+
+    def forward(self, x):
+        """Forward function."""
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(self.convs, x)
+        else:
+            out = self.convs(x)
+        return out
+
+
+@UPSAMPLE_LAYERS.register_module()
+class DeconvModule(nn.Module):
+    """Deconvolution upsample module in decoder for UNet (2X upsample).
+
+    This module uses deconvolution to upsample the feature map in the decoder
+    of UNet.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+        norm_cfg (dict | None): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict | None): Config dict for activation layer in ConvModule.
+            Default: dict(type='ReLU').
+        kernel_size (int): Kernel size of the deconvolution layer. Default: 4.
+        scale_factor (int): Upsampling scale factor, used as the stride of
+            the deconvolution layer. Default: 2.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 with_cp=False,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 *,
+                 kernel_size=4,
+                 scale_factor=2):
+        super(DeconvModule, self).__init__()
+
+        assert (kernel_size - scale_factor >= 0) and\
+               (kernel_size - scale_factor) % 2 == 0,\
+               f'kernel_size should be greater than or equal to scale_factor '\
+               f'and (kernel_size - scale_factor) should be even numbers, '\
+               f'while the kernel size is {kernel_size} and scale_factor is '\
+               f'{scale_factor}.'
+
+        stride = scale_factor
+        padding = (kernel_size - scale_factor) // 2
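+        # With the defaults (kernel_size=4, scale_factor=2) this gives
+        # stride=2 and padding=1, so the ConvTranspose2d output size is
+        # (H - 1) * 2 - 2 * 1 + 4 = 2 * H, i.e. an exact 2x upsample.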
+        self.with_cp = with_cp
+        deconv = nn.ConvTranspose2d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding)
+
+        norm_name, norm = build_norm_layer(norm_cfg, out_channels)
+        activate = build_activation_layer(act_cfg)
+        self.deconv_upsamping = nn.Sequential(deconv, norm, activate)
+
+    def forward(self, x):
+        """Forward function."""
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(self.deconv_upsamping, x)
+        else:
+            out = self.deconv_upsamping(x)
+        return out
+
+
+@UPSAMPLE_LAYERS.register_module()
+class InterpConv(nn.Module):
+    """Interpolation upsample module in decoder for UNet.
+
+    This module uses interpolation to upsample the feature map in the decoder
+    of UNet. It consists of one interpolation upsample layer and one
+    convolutional layer, applied either as upsample followed by conv
+    (conv_first=False) or as conv followed by upsample (conv_first=True).
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+        norm_cfg (dict | None): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict | None): Config dict for activation layer in ConvModule.
+            Default: dict(type='ReLU').
+        conv_cfg (dict | None): Config dict for convolution layer.
+            Default: None.
+        conv_first (bool): Whether the convolutional layer comes before the
+            interpolation upsample layer. Default: False, i.e. the upsample
+            layer is followed by the convolutional layer.
+        kernel_size (int): Kernel size of the convolutional layer. Default: 1.
+        stride (int): Stride of the convolutional layer. Default: 1.
+        padding (int): Padding of the convolutional layer. Default: 0.
+        upsample_cfg (dict): Interpolation config of the upsample layer.
+            Default: dict(
+                scale_factor=2, mode='bilinear', align_corners=False).
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 with_cp=False,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 *,
+                 conv_cfg=None,
+                 conv_first=False,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+                 upsample_cfg=dict(
+                     scale_factor=2, mode='bilinear', align_corners=False)):
+        super(InterpConv, self).__init__()
+
+        self.with_cp = with_cp
+        conv = ConvModule(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+        upsample = nn.Upsample(**upsample_cfg)
+        if conv_first:
+            self.interp_upsample = nn.Sequential(conv, upsample)
+        else:
+            self.interp_upsample = nn.Sequential(upsample, conv)
+
+    def forward(self, x):
+        """Forward function."""
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(self.interp_upsample, x)
+        else:
+            out = self.interp_upsample(x)
+        return out
+
+
+@BACKBONES.register_module()
+class UNet(nn.Module):
+    """UNet backbone.
+    U-Net: Convolutional Networks for Biomedical Image Segmentation.
+    https://arxiv.org/pdf/1505.04597.pdf
+
+    Args:
+        in_channels (int): Number of input image channels. Default: 3.
+        base_channels (int): Number of base channels of each stage.
+            The output channels of the first stage. Default: 64.
+        num_stages (int): Number of stages in encoder, normally 5. Default: 5.
+        strides (Sequence[int]): Stride (1 or 2) of each encoder stage.
+            len(strides) must equal num_stages. Normally the stride of the
+            first encoder stage is 1. If strides[i]=2, the corresponding
+            encoder stage uses a strided convolution to downsample.
+            Default: (1, 1, 1, 1, 1).
+        enc_num_convs (Sequence[int]): Number of convolutional layers in the
+            convolution block of the corresponding encoder stage.
+            Default: (2, 2, 2, 2, 2).
+        dec_num_convs (Sequence[int]): Number of convolutional layers in the
+            convolution block of the corresponding decoder stage.
+            Default: (2, 2, 2, 2).
+        downsamples (Sequence[bool]): Whether to use MaxPool to downsample the
+            feature map in encoder stages 1 to num_stages-1. If the
+            corresponding encoder stage uses a strided convolution
+            (strides[i]=2), MaxPool is never used, even if
+            downsamples[i-1]=True. Default: (True, True, True, True).
+        enc_dilations (Sequence[int]): Dilation rate of each stage in encoder.
+            Default: (1, 1, 1, 1, 1).
+        dec_dilations (Sequence[int]): Dilation rate of each stage in decoder.
+            Default: (1, 1, 1, 1).
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+        conv_cfg (dict | None): Config dict for convolution layer.
+            Default: None.
+        norm_cfg (dict | None): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict | None): Config dict for activation layer in ConvModule.
+            Default: dict(type='ReLU').
+        upsample_cfg (dict): The upsample config of the upsample module in
+            decoder. Default: dict(type='InterpConv').
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Default: False.
+        dcn (dict | None): Config dict for deformable convolution.
+            Not implemented yet. Default: None.
+        plugins (dict | None): Plugins for convolutional layers.
+            Not implemented yet. Default: None.
+
+    Notice:
+        The input image size should be divisible by the whole downsample rate
+        of the encoder. For example, with the default configuration (all
+        strides 1 and four MaxPool downsamples) the whole downsample rate is
+        2**4 = 16, so the input height and width must be multiples of 16; see
+        the example below. More detail can be found in
+        UNet._check_input_divisible.
+
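+    Example:
+        >>> # Illustrative example; the expected shapes assume the default
+        >>> # configuration (whole downsample rate 16) and a 64x64 input, so
+        >>> # dec_outs runs from the 4x4 bottleneck back up to full resolution.
+        >>> import torch
+        >>> self = UNet()
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 64, 64)
+        >>> dec_outs = self.forward(inputs)
+        >>> for dec_out in dec_outs:
+        ...     print(tuple(dec_out.shape))
+        (1, 1024, 4, 4)
+        (1, 512, 8, 8)
+        (1, 256, 16, 16)
+        (1, 128, 32, 32)
+        (1, 64, 64, 64)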
+    """
+
+    def __init__(self,
+                 in_channels=3,
+                 base_channels=64,
+                 num_stages=5,
+                 strides=(1, 1, 1, 1, 1),
+                 enc_num_convs=(2, 2, 2, 2, 2),
+                 dec_num_convs=(2, 2, 2, 2),
+                 downsamples=(True, True, True, True),
+                 enc_dilations=(1, 1, 1, 1, 1),
+                 dec_dilations=(1, 1, 1, 1),
+                 with_cp=False,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 upsample_cfg=dict(type='InterpConv'),
+                 norm_eval=False,
+                 dcn=None,
+                 plugins=None):
+        super(UNet, self).__init__()
+        assert dcn is None, 'Not implemented yet.'
+        assert plugins is None, 'Not implemented yet.'
+        assert len(strides) == num_stages, \
+            'The length of strides should be equal to num_stages, '\
+            f'while the strides is {strides}, the length of '\
+            f'strides is {len(strides)}, and the num_stages is '\
+            f'{num_stages}.'
+        assert len(enc_num_convs) == num_stages, \
+            'The length of enc_num_convs should be equal to num_stages, '\
+            f'while the enc_num_convs is {enc_num_convs}, the length of '\
+            f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
+            f'{num_stages}.'
+        assert len(dec_num_convs) == (num_stages-1), \
+            'The length of dec_num_convs should be equal to (num_stages-1), '\
+            f'while the dec_num_convs is {dec_num_convs}, the length of '\
+            f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
+            f'{num_stages}.'
+        assert len(downsamples) == (num_stages-1), \
+            'The length of downsamples should be equal to (num_stages-1), '\
+            f'while the downsamples is {downsamples}, the length of '\
+            f'downsamples is {len(downsamples)}, and the num_stages is '\
+            f'{num_stages}.'
+        assert len(enc_dilations) == num_stages, \
+            'The length of enc_dilations should be equal to num_stages, '\
+            f'while the enc_dilations is {enc_dilations}, the length of '\
+            f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
+            f'{num_stages}.'
+        assert len(dec_dilations) == (num_stages-1), \
+            'The length of dec_dilations should be equal to (num_stages-1), '\
+            f'while the dec_dilations is {dec_dilations}, the length of '\
+            f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
+            f'{num_stages}.'
+        self.num_stages = num_stages
+        self.strides = strides
+        self.downsamples = downsamples
+        self.norm_eval = norm_eval
+        self.base_channels = base_channels
+
+        self.encoder = nn.ModuleList()
+        self.decoder = nn.ModuleList()
+
+        for i in range(num_stages):
+            enc_conv_block = []
+            if i != 0:
+                if strides[i] == 1 and downsamples[i - 1]:
+                    enc_conv_block.append(nn.MaxPool2d(kernel_size=2))
+                upsample = (strides[i] != 1 or downsamples[i - 1])
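+                # The matching decoder block only needs to upsample if this
+                # encoder stage reduced the resolution (strided conv or
+                # MaxPool); otherwise the skip and decoder features already
+                # share the same spatial size.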
+                self.decoder.append(
+                    UpConvBlock(
+                        conv_block=BasicConvBlock,
+                        in_channels=base_channels * 2**i,
+                        skip_channels=base_channels * 2**(i - 1),
+                        out_channels=base_channels * 2**(i - 1),
+                        num_convs=dec_num_convs[i - 1],
+                        stride=1,
+                        dilation=dec_dilations[i - 1],
+                        with_cp=with_cp,
+                        conv_cfg=conv_cfg,
+                        norm_cfg=norm_cfg,
+                        act_cfg=act_cfg,
+                        upsample_cfg=upsample_cfg if upsample else None,
+                        dcn=None,
+                        plugins=None))
+
+            enc_conv_block.append(
+                BasicConvBlock(
+                    in_channels=in_channels,
+                    out_channels=base_channels * 2**i,
+                    num_convs=enc_num_convs[i],
+                    stride=strides[i],
+                    dilation=enc_dilations[i],
+                    with_cp=with_cp,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg,
+                    dcn=None,
+                    plugins=None))
+            self.encoder.append((nn.Sequential(*enc_conv_block)))
+            in_channels = base_channels * 2**i
+
+    def forward(self, x):
+        self._check_input_divisible(x)
+        enc_outs = []
+        for enc in self.encoder:
+            x = enc(x)
+            enc_outs.append(x)
+        dec_outs = [x]
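+        # dec_outs runs from coarse to fine: dec_outs[0] is the bottleneck
+        # feature and each decoder step appends a progressively finer map,
+        # so dec_outs[-1] has the highest resolution.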
+        for i in reversed(range(len(self.decoder))):
+            x = self.decoder[i](enc_outs[i], x)
+            dec_outs.append(x)
+
+        return dec_outs
+
+    def train(self, mode=True):
+        """Convert the model into training mode while keep normalization layer
+        freezed."""
+        super(UNet, self).train(mode)
+        if mode and self.norm_eval:
+            for m in self.modules():
+                # trick: eval() only has an effect on BatchNorm layers
+                if isinstance(m, _BatchNorm):
+                    m.eval()
+
+    def _check_input_divisible(self, x):
+        h, w = x.shape[-2:]
+        whole_downsample_rate = 1
+        for i in range(1, self.num_stages):
+            if self.strides[i] == 2 or self.downsamples[i - 1]:
+                whole_downsample_rate *= 2
+        assert (h % whole_downsample_rate == 0) \
+            and (w % whole_downsample_rate == 0),\
+            f'The input image size {(h, w)} should be divisible by the whole '\
+            f'downsample rate {whole_downsample_rate}, when num_stages is '\
+            f'{self.num_stages}, strides is {self.strides}, and downsamples '\
+            f'is {self.downsamples}.'
+
+    def init_weights(self, pretrained=None):
+        """Initialize the weights in backbone.
+
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+        """
+        if isinstance(pretrained, str):
+            logger = get_root_logger()
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    kaiming_init(m)
+                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+                    constant_init(m, 1)
+        else:
+            raise TypeError('pretrained must be a str or None')
diff --git a/annotator/uniformer/mmseg/models/backbones/uniformer.py b/annotator/uniformer/mmseg/models/backbones/uniformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c4bb88e4c928540cca9ab609988b916520f5b7a
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/backbones/uniformer.py
@@ -0,0 +1,422 @@
+# --------------------------------------------------------
+# UniFormer
+# Copyright (c) 2022 SenseTime X-Lab
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Kunchang Li
+# --------------------------------------------------------
+
+from collections import OrderedDict
+import math
+
+from functools import partial
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+import numpy as np
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+from annotator.uniformer.mmcv_custom import load_checkpoint
+from annotator.uniformer.mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class CMlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
+        self.act = act_layer()
+        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class CBlock(nn.Module):
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
+        self.norm1 = nn.BatchNorm2d(dim)
+        self.conv1 = nn.Conv2d(dim, dim, 1)
+        self.conv2 = nn.Conv2d(dim, dim, 1)
+        self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = nn.BatchNorm2d(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+    def forward(self, x):
+        x = x + self.pos_embed(x)
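+        # Local token aggregation: 1x1 conv -> depthwise 5x5 conv -> 1x1 conv,
+        # used in the early (convolutional) stages instead of self-attention.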
+        x = x + self.drop_path(self.conv2(self.attn(self.conv1(self.norm1(x)))))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class Attention(nn.Module):
+    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x):
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class SABlock(nn.Module):
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+            attn_drop=attn_drop, proj_drop=drop)
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+    def forward(self, x):
+        x = x + self.pos_embed(x)
+        B, C, H, W = x.shape
+        x = x.flatten(2).transpose(1, 2)
+        x = x + self.drop_path(self.attn(self.norm1(x)))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        x = x.transpose(1, 2).reshape(B, C, H, W)
+        return x   
+
+
+def window_partition(x, window_size):
+    """
+    Args:
+        x: (B, H, W, C)
+        window_size (int): window size
+    Returns:
+        windows: (num_windows*B, window_size, window_size, C)
+    """
+    B, H, W, C = x.shape
+    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    return windows
+
+
+def window_reverse(windows, window_size, H, W):
+    """
+    Args:
+        windows: (num_windows*B, window_size, window_size, C)
+        window_size (int): Window size
+        H (int): Height of image
+        W (int): Width of image
+    Returns:
+        x: (B, H, W, C)
+    """
+    B = int(windows.shape[0] / (H * W / window_size / window_size))
+    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    return x
+
+
+class SABlock_Windows(nn.Module):
+    def __init__(self, dim, num_heads, window_size=14, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.window_size=window_size
+        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
+            attn_drop=attn_drop, proj_drop=drop)
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+    def forward(self, x):
+        x = x + self.pos_embed(x)
+        x = x.permute(0, 2, 3, 1)
+        B, H, W, C = x.shape
+        shortcut = x
+        x = self.norm1(x)
+
+        pad_l = pad_t = 0
+        pad_r = (self.window_size - W % self.window_size) % self.window_size
+        pad_b = (self.window_size - H % self.window_size) % self.window_size
+        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+        _, Hp, Wp, _ = x.shape
+        
+        x_windows = window_partition(x, self.window_size)  # nW*B, window_size, window_size, C
+        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
+
+        # W-MSA/SW-MSA
+        attn_windows = self.attn(x_windows)  # nW*B, window_size*window_size, C
+
+        # merge windows
+        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+        x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C
+
+        # remove the padding added before window partition
+        if pad_r > 0 or pad_b > 0:
+            x = x[:, :H, :W, :].contiguous()
+
+        x = shortcut + self.drop_path(x)
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        x = x.permute(0, 3, 1, 2).reshape(B, C, H, W)
+        return x 
+             
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    """
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.num_patches = num_patches
+        self.norm = nn.LayerNorm(embed_dim)
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x):
+        B, _, H, W = x.shape
+        x = self.proj(x)
+        B, _, H, W = x.shape
+        x = x.flatten(2).transpose(1, 2)
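+        # flatten to (B, H*W, C) so LayerNorm normalizes over the channel
+        # dimension, then restore the (B, C, H, W) layout below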
+        x = self.norm(x)
+        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+        return x
+    
+
+@BACKBONES.register_module()   
+class UniFormer(nn.Module):
+    """ Vision Transformer
+    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`  -
+        https://arxiv.org/abs/2010.11929
+    """
+    def __init__(self, layers=[3, 4, 8, 3], img_size=224, in_chans=3, num_classes=80, embed_dim=[64, 128, 320, 512],
+                 head_dim=64, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
+                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
+                 pretrained_path=None, use_checkpoint=False, checkpoint_num=[0, 0, 0, 0], 
+                 windows=False, hybrid=False, window_size=14):
+        """
+        Args:
+            layers (list): number of blocks in each stage
+            img_size (int, tuple): input image size
+            in_chans (int): number of input channels
+            num_classes (int): number of classes for classification head
+            embed_dim (list): embedding dimension of each stage
+            head_dim (int): dimension of attention heads
+            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+            qkv_bias (bool): enable bias for qkv if True
+            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
+            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
+            drop_rate (float): dropout rate
+            attn_drop_rate (float): attention dropout rate
+            drop_path_rate (float): stochastic depth rate
+            norm_layer (nn.Module): normalization layer
+            pretrained_path (str): path of pretrained model
+            use_checkpoint (bool): whether to use gradient checkpointing
+            checkpoint_num (list): number of leading blocks in each stage that use checkpointing
+            windows (bool): whether to use window MHRA for all blocks in stage3
+            hybrid (bool): whether to mix window and global MHRA in stage3
+            window_size (int): size of the local attention window (normally >= 14; default 14)
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.use_checkpoint = use_checkpoint
+        self.checkpoint_num = checkpoint_num
+        self.windows = windows
+        print(f'Use Checkpoint: {self.use_checkpoint}')
+        print(f'Checkpoint Number: {self.checkpoint_num}')
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 
+        
+        self.patch_embed1 = PatchEmbed(
+            img_size=img_size, patch_size=4, in_chans=in_chans, embed_dim=embed_dim[0])
+        self.patch_embed2 = PatchEmbed(
+            img_size=img_size // 4, patch_size=2, in_chans=embed_dim[0], embed_dim=embed_dim[1])
+        self.patch_embed3 = PatchEmbed(
+            img_size=img_size // 8, patch_size=2, in_chans=embed_dim[1], embed_dim=embed_dim[2])
+        self.patch_embed4 = PatchEmbed(
+            img_size=img_size // 16, patch_size=2, in_chans=embed_dim[2], embed_dim=embed_dim[3])
+
+        self.pos_drop = nn.Dropout(p=drop_rate)
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(layers))]  # stochastic depth decay rule
+        num_heads = [dim // head_dim for dim in embed_dim]
+        self.blocks1 = nn.ModuleList([
+            CBlock(
+                dim=embed_dim[0], num_heads=num_heads[0], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
+            for i in range(layers[0])])
+        self.norm1=norm_layer(embed_dim[0])
+        self.blocks2 = nn.ModuleList([
+            CBlock(
+                dim=embed_dim[1], num_heads=num_heads[1], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]], norm_layer=norm_layer)
+            for i in range(layers[1])])
+        self.norm2 = norm_layer(embed_dim[1])
+        if self.windows:
+            print('Use local window for all blocks in stage3')
+            self.blocks3 = nn.ModuleList([
+            SABlock_Windows(
+                dim=embed_dim[2], num_heads=num_heads[2], window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]], norm_layer=norm_layer)
+            for i in range(layers[2])])
+        elif hybrid:
+            print('Use hybrid window for blocks in stage3')
+            block3 = []
+            for i in range(layers[2]):
+                if (i + 1) % 4 == 0:
+                    block3.append(SABlock(
+                    dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                    drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]], norm_layer=norm_layer))
+                else:
+                    block3.append(SABlock_Windows(
+                    dim=embed_dim[2], num_heads=num_heads[2], window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                    drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]], norm_layer=norm_layer))
+            self.blocks3 = nn.ModuleList(block3)
+        else:
+            print('Use global window for all blocks in stage3')
+            self.blocks3 = nn.ModuleList([
+            SABlock(
+                dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]], norm_layer=norm_layer)
+            for i in range(layers[2])])
+        self.norm3 = norm_layer(embed_dim[2])
+        self.blocks4 = nn.ModuleList([
+            SABlock(
+                dim=embed_dim[3], num_heads=num_heads[3], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]+layers[2]], norm_layer=norm_layer)
+            for i in range(layers[3])])
+        self.norm4 = norm_layer(embed_dim[3])
+        
+        # Representation layer
+        if representation_size:
+            self.num_features = representation_size
+            self.pre_logits = nn.Sequential(OrderedDict([
+                ('fc', nn.Linear(embed_dim, representation_size)),
+                ('act', nn.Tanh())
+            ]))
+        else:
+            self.pre_logits = nn.Identity()
+        
+        self.apply(self._init_weights)
+        self.init_weights(pretrained=pretrained_path)
+        
+    def init_weights(self, pretrained):
+        if isinstance(pretrained, str):
+            logger = get_root_logger()
+            load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
+            print(f'Load pretrained model from {pretrained}')
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'pos_embed', 'cls_token'}
+
+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool=''):
+        self.num_classes = num_classes
+        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+    def forward_features(self, x):
+        out = []
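+        # Collect a 4-level feature pyramid: the patch embeddings downsample
+        # by 4, 2, 2 and 2, so the outputs have strides 4, 8, 16 and 32 with
+        # embed_dim[0..3] channels.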
+        x = self.patch_embed1(x)
+        x = self.pos_drop(x)
+        for i, blk in enumerate(self.blocks1):
+            if self.use_checkpoint and i < self.checkpoint_num[0]:
+                x = checkpoint.checkpoint(blk, x)
+            else:
+                x = blk(x)
+        x_out = self.norm1(x.permute(0, 2, 3, 1))
+        out.append(x_out.permute(0, 3, 1, 2).contiguous())
+        x = self.patch_embed2(x)
+        for i, blk in enumerate(self.blocks2):
+            if self.use_checkpoint and i < self.checkpoint_num[1]:
+                x = checkpoint.checkpoint(blk, x)
+            else:
+                x = blk(x)
+        x_out = self.norm2(x.permute(0, 2, 3, 1))
+        out.append(x_out.permute(0, 3, 1, 2).contiguous())
+        x = self.patch_embed3(x)
+        for i, blk in enumerate(self.blocks3):
+            if self.use_checkpoint and i < self.checkpoint_num[2]:
+                x = checkpoint.checkpoint(blk, x)
+            else:
+                x = blk(x)
+        x_out = self.norm3(x.permute(0, 2, 3, 1))
+        out.append(x_out.permute(0, 3, 1, 2).contiguous())
+        x = self.patch_embed4(x)
+        for i, blk in enumerate(self.blocks4):
+            if self.use_checkpoint and i < self.checkpoint_num[3]:
+                x = checkpoint.checkpoint(blk, x)
+            else:
+                x = blk(x)
+        x_out = self.norm4(x.permute(0, 2, 3, 1))
+        out.append(x_out.permute(0, 3, 1, 2).contiguous())
+        return tuple(out)
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        return x
diff --git a/annotator/uniformer/mmseg/models/backbones/vit.py b/annotator/uniformer/mmseg/models/backbones/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..59e4479650690e08cbc4cab9427aefda47c2116d
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/backbones/vit.py
@@ -0,0 +1,459 @@
+"""Modified from https://github.com/rwightman/pytorch-image-
+models/blob/master/timm/models/vision_transformer.py."""
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from annotator.uniformer.mmcv.cnn import (Conv2d, Linear, build_activation_layer, build_norm_layer,
+                      constant_init, kaiming_init, normal_init)
+from annotator.uniformer.mmcv.runner import _load_checkpoint
+from annotator.uniformer.mmcv.utils.parrots_wrapper import _BatchNorm
+
+from annotator.uniformer.mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+from ..utils import DropPath, trunc_normal_
+
+
+class Mlp(nn.Module):
+    """MLP layer for Encoder block.
+
+    Args:
+        in_features(int): Input dimension for the first fully
+            connected layer.
+        hidden_features(int): Output dimension for the first fully
+            connected layer.
+        out_features(int): Output dimension for the second fully
+            connected layer.
+        act_cfg(dict): Config dict for activation layer.
+            Default: dict(type='GELU').
+        drop(float): Drop rate for the dropout layer. Dropout rate has
+            to be between 0 and 1. Default: 0.
+    """
+
+    def __init__(self,
+                 in_features,
+                 hidden_features=None,
+                 out_features=None,
+                 act_cfg=dict(type='GELU'),
+                 drop=0.):
+        super(Mlp, self).__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = Linear(in_features, hidden_features)
+        self.act = build_activation_layer(act_cfg)
+        self.fc2 = Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Module):
+    """Attention layer for Encoder block.
+
+    Args:
+        dim (int): Dimension for the input vector.
+        num_heads (int): Number of parallel attention heads.
+        qkv_bias (bool): Enable bias for qkv if True. Default: False.
+        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+        attn_drop (float): Drop rate for attention output weights.
+            Default: 0.
+        proj_drop (float): Drop rate for output weights. Default: 0.
+    """
+
+    def __init__(self,
+                 dim,
+                 num_heads=8,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 attn_drop=0.,
+                 proj_drop=0.):
+        super(Attention, self).__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim**-0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x):
+        b, n, c = x.shape
+        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads,
+                                  c // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]
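+        # q, k and v each have shape (b, num_heads, n, head_dim)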
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(b, n, c)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class Block(nn.Module):
+    """Implements encoder block with residual connection.
+
+    Args:
+        dim (int): The feature dimension.
+        num_heads (int): Number of parallel attention heads.
+        mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
+        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float): Drop rate for mlp output weights. Default: 0.
+        attn_drop (float): Drop rate for attention output weights.
+            Default: 0.
+        proj_drop (float): Drop rate for attn layer output weights.
+            Default: 0.
+        drop_path (float): Drop rate for paths of model.
+            Default: 0.
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='GELU').
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='LN', eps=1e-6).
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+    """
+
+    def __init__(self,
+                 dim,
+                 num_heads,
+                 mlp_ratio=4,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 proj_drop=0.,
+                 drop_path=0.,
+                 act_cfg=dict(type='GELU'),
+                 norm_cfg=dict(type='LN', eps=1e-6),
+                 with_cp=False):
+        super(Block, self).__init__()
+        self.with_cp = with_cp
+        _, self.norm1 = build_norm_layer(norm_cfg, dim)
+        self.attn = Attention(dim, num_heads, qkv_bias, qk_scale, attn_drop,
+                              proj_drop)
+        self.drop_path = DropPath(
+            drop_path) if drop_path > 0. else nn.Identity()
+        _, self.norm2 = build_norm_layer(norm_cfg, dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=mlp_hidden_dim,
+            act_cfg=act_cfg,
+            drop=drop)
+
+    def forward(self, x):
+
+        def _inner_forward(x):
+            out = x + self.drop_path(self.attn(self.norm1(x)))
+            out = out + self.drop_path(self.mlp(self.norm2(out)))
+            return out
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(_inner_forward, x)
+        else:
+            out = _inner_forward(x)
+
+        return out
+
+
+class PatchEmbed(nn.Module):
+    """Image to Patch Embedding.
+
+    Args:
+        img_size (int | tuple): Input image size.
+            default: 224.
+        patch_size (int): Width and height for a patch.
+            default: 16.
+        in_channels (int): Input channels for images. Default: 3.
+        embed_dim (int): The embedding dimension. Default: 768.
+    """
+
+    def __init__(self,
+                 img_size=224,
+                 patch_size=16,
+                 in_channels=3,
+                 embed_dim=768):
+        super(PatchEmbed, self).__init__()
+        if isinstance(img_size, int):
+            self.img_size = (img_size, img_size)
+        elif isinstance(img_size, tuple):
+            self.img_size = img_size
+        else:
+            raise TypeError('img_size must be type of int or tuple')
+        h, w = self.img_size
+        self.patch_size = (patch_size, patch_size)
+        self.num_patches = (h // patch_size) * (w // patch_size)
+        self.proj = Conv2d(
+            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x):
+        return self.proj(x).flatten(2).transpose(1, 2)
+
+
+@BACKBONES.register_module()
+class VisionTransformer(nn.Module):
+    """Vision transformer backbone.
+
+    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for
+        Image Recognition at Scale` - https://arxiv.org/abs/2010.11929
+
+    Args:
+        img_size (tuple): input image size. Default: (224, 224).
+        patch_size (int, tuple): patch size. Default: 16.
+        in_channels (int): number of input channels. Default: 3.
+        embed_dim (int): embedding dimension. Default: 768.
+        depth (int): depth of transformer. Default: 12.
+        num_heads (int): number of attention heads. Default: 12.
+        mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
+            Default: 4.
+        out_indices (list | tuple | int): Output from which stages.
+            Default: 11.
+        qkv_bias (bool): enable bias for qkv if True. Default: True.
+        qk_scale (float): override default qk scale of head_dim ** -0.5 if set.
+        drop_rate (float): dropout rate. Default: 0.
+        attn_drop_rate (float): attention dropout rate. Default: 0.
+        drop_path_rate (float): Rate of DropPath. Default: 0.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='LN', eps=1e-6, requires_grad=True).
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='GELU').
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Default: False.
+        final_norm (bool): Whether to add an additional layer to normalize
+            the final feature map. Default: False.
+        interpolate_mode (str): Select the interpolate mode for position
+            embedding vector resizing. Default: bicubic.
+        with_cls_token (bool): Whether to concatenate the class token with
+            the image tokens as transformer input. Default: True.
+        with_cp (bool): Use checkpoint or not. Using checkpoint
+            will save some memory while slowing down the training speed.
+            Default: False.
+    """
+
+    def __init__(self,
+                 img_size=(224, 224),
+                 patch_size=16,
+                 in_channels=3,
+                 embed_dim=768,
+                 depth=12,
+                 num_heads=12,
+                 mlp_ratio=4,
+                 out_indices=11,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 norm_cfg=dict(type='LN', eps=1e-6, requires_grad=True),
+                 act_cfg=dict(type='GELU'),
+                 norm_eval=False,
+                 final_norm=False,
+                 with_cls_token=True,
+                 interpolate_mode='bicubic',
+                 with_cp=False):
+        super(VisionTransformer, self).__init__()
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.features = self.embed_dim = embed_dim
+        self.patch_embed = PatchEmbed(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_channels=in_channels,
+            embed_dim=embed_dim)
+
+        self.with_cls_token = with_cls_token
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+        self.pos_embed = nn.Parameter(
+            torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        if isinstance(out_indices, int):
+            self.out_indices = [out_indices]
+        elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
+            self.out_indices = out_indices
+        else:
+            raise TypeError('out_indices must be type of int, list or tuple')
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
+               ]  # stochastic depth decay rule
+        self.blocks = nn.ModuleList([
+            Block(
+                dim=embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=dpr[i],
+                attn_drop=attn_drop_rate,
+                act_cfg=act_cfg,
+                norm_cfg=norm_cfg,
+                with_cp=with_cp) for i in range(depth)
+        ])
+
+        self.interpolate_mode = interpolate_mode
+        self.final_norm = final_norm
+        if final_norm:
+            _, self.norm = build_norm_layer(norm_cfg, embed_dim)
+
+        self.norm_eval = norm_eval
+        self.with_cp = with_cp
+
+    def init_weights(self, pretrained=None):
+        if isinstance(pretrained, str):
+            logger = get_root_logger()
+            checkpoint = _load_checkpoint(pretrained, logger=logger)
+            if 'state_dict' in checkpoint:
+                state_dict = checkpoint['state_dict']
+            else:
+                state_dict = checkpoint
+
+            if 'pos_embed' in state_dict.keys():
+                if self.pos_embed.shape != state_dict['pos_embed'].shape:
+                    logger.info(msg=f'Resize the pos_embed shape from \
+{state_dict["pos_embed"].shape} to {self.pos_embed.shape}')
+                    h, w = self.img_size
+                    pos_size = int(
+                        math.sqrt(state_dict['pos_embed'].shape[1] - 1))
+                    state_dict['pos_embed'] = self.resize_pos_embed(
+                        state_dict['pos_embed'], (h, w), (pos_size, pos_size),
+                        self.patch_size, self.interpolate_mode)
+
+            self.load_state_dict(state_dict, False)
+
+        elif pretrained is None:
+            # We only implement the 'jax_impl' initialization implemented at
+            # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353  # noqa: E501
+            trunc_normal_(self.pos_embed, std=.02)
+            trunc_normal_(self.cls_token, std=.02)
+            for n, m in self.named_modules():
+                if isinstance(m, Linear):
+                    trunc_normal_(m.weight, std=.02)
+                    if m.bias is not None:
+                        if 'mlp' in n:
+                            normal_init(m.bias, std=1e-6)
+                        else:
+                            constant_init(m.bias, 0)
+                elif isinstance(m, Conv2d):
+                    kaiming_init(m.weight, mode='fan_in')
+                    if m.bias is not None:
+                        constant_init(m.bias, 0)
+                elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
+                    constant_init(m.bias, 0)
+                    constant_init(m.weight, 1.0)
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    def _pos_embeding(self, img, patched_img, pos_embed):
+        """Positiong embeding method.
+
+        Resize the pos_embed, if the input image size doesn't match
+            the training size.
+        Args:
+            img (torch.Tensor): The inference image tensor, the shape
+                must be [B, C, H, W].
+            patched_img (torch.Tensor): The patched image; it should be
+                of shape [B, L1, C].
+            pos_embed (torch.Tensor): The pos_embed weights; it should be
+                of shape [B, L2, C].
+        Returns:
+            torch.Tensor: The pos encoded image feature.
+        """
+        assert patched_img.ndim == 3 and pos_embed.ndim == 3, \
+            'the shapes of patched_img and pos_embed must be [B, L, C]'
+        x_len, pos_len = patched_img.shape[1], pos_embed.shape[1]
+        if x_len != pos_len:
+            if pos_len == (self.img_size[0] // self.patch_size) * (
+                    self.img_size[1] // self.patch_size) + 1:
+                pos_h = self.img_size[0] // self.patch_size
+                pos_w = self.img_size[1] // self.patch_size
+            else:
+                raise ValueError(
+                    'Unexpected shape of pos_embed, got {}.'.format(
+                        pos_embed.shape))
+            pos_embed = self.resize_pos_embed(pos_embed, img.shape[2:],
+                                              (pos_h, pos_w), self.patch_size,
+                                              self.interpolate_mode)
+        return self.pos_drop(patched_img + pos_embed)
+
+    @staticmethod
+    def resize_pos_embed(pos_embed, input_shpae, pos_shape, patch_size, mode):
+        """Resize pos_embed weights.
+
+        Resize pos_embed using the given interpolation mode.
+        Args:
+            pos_embed (torch.Tensor): pos_embed weights.
+            input_shpae (tuple): Tuple for (input_h, input_w).
+            pos_shape (tuple): Tuple for (pos_h, pos_w).
+            patch_size (int): Patch size.
+            mode (str): Interpolation mode passed to ``F.interpolate``.
+        Returns:
+            torch.Tensor: The resized pos_embed of shape [B, L_new, C]
+        """
+        assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
+        input_h, input_w = input_shpae
+        pos_h, pos_w = pos_shape
+        cls_token_weight = pos_embed[:, 0]
+        pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
+        pos_embed_weight = pos_embed_weight.reshape(
+            1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
+        pos_embed_weight = F.interpolate(
+            pos_embed_weight,
+            size=[input_h // patch_size, input_w // patch_size],
+            align_corners=False,
+            mode=mode)
+        cls_token_weight = cls_token_weight.unsqueeze(1)
+        pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
+        pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
+        return pos_embed
+
+    def forward(self, inputs):
+        B = inputs.shape[0]
+
+        x = self.patch_embed(inputs)
+
+        cls_tokens = self.cls_token.expand(B, -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
+        x = self._pos_embeding(inputs, x, self.pos_embed)
+
+        if not self.with_cls_token:
+            # Remove class token for transformer input
+            x = x[:, 1:]
+
+        outs = []
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if i == len(self.blocks) - 1:
+                if self.final_norm:
+                    x = self.norm(x)
+            if i in self.out_indices:
+                if self.with_cls_token:
+                    # Remove class token and reshape token for decoder head
+                    out = x[:, 1:]
+                else:
+                    out = x
+                B, _, C = out.shape
+                out = out.reshape(B, inputs.shape[2] // self.patch_size,
+                                  inputs.shape[3] // self.patch_size,
+                                  C).permute(0, 3, 1, 2)
+                outs.append(out)
+
+        return tuple(outs)
+
+    def train(self, mode=True):
+        super(VisionTransformer, self).train(mode)
+        if mode and self.norm_eval:
+            for m in self.modules():
+                if isinstance(m, nn.LayerNorm):
+                    m.eval()
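# Illustrative sketch, not from the original files: how the backbone above adapts
# a pretrained position embedding to a new input resolution, mirroring
# VisionTransformer.resize_pos_embed. The helper name and the 224 -> 512 numbers
# are assumptions made for illustration only.
import torch
import torch.nn.functional as F

def resize_vit_pos_embed(pos_embed, new_hw, old_grid, patch_size, mode='bicubic'):
    # pos_embed: [1, 1 + old_grid * old_grid, C]; slot 0 is the class token.
    cls_tok, grid_tok = pos_embed[:, :1], pos_embed[:, 1:]
    c = pos_embed.shape[2]
    # Reshape to [1, C, old_grid, old_grid] so F.interpolate can resample it spatially.
    grid_tok = grid_tok.reshape(1, old_grid, old_grid, c).permute(0, 3, 1, 2)
    new_grid = (new_hw[0] // patch_size, new_hw[1] // patch_size)
    grid_tok = F.interpolate(grid_tok, size=new_grid, mode=mode, align_corners=False)
    grid_tok = grid_tok.flatten(2).transpose(1, 2)        # back to [1, L_new, C]
    return torch.cat((cls_tok, grid_tok), dim=1)

pe = torch.zeros(1, 14 * 14 + 1, 768)                     # 224 / 16 = 14 patches per side
print(resize_vit_pos_embed(pe, (512, 512), 14, 16).shape) # torch.Size([1, 1025, 768])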
diff --git a/annotator/uniformer/mmseg/models/builder.py b/annotator/uniformer/mmseg/models/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f5b971252bfc971c3ffbaa27746d69b1d3ea9fd
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/builder.py
@@ -0,0 +1,46 @@
+import warnings
+
+from annotator.uniformer.mmcv.cnn import MODELS as MMCV_MODELS
+from annotator.uniformer.mmcv.utils import Registry
+
+MODELS = Registry('models', parent=MMCV_MODELS)
+
+BACKBONES = MODELS
+NECKS = MODELS
+HEADS = MODELS
+LOSSES = MODELS
+SEGMENTORS = MODELS
+
+
+def build_backbone(cfg):
+    """Build backbone."""
+    return BACKBONES.build(cfg)
+
+
+def build_neck(cfg):
+    """Build neck."""
+    return NECKS.build(cfg)
+
+
+def build_head(cfg):
+    """Build head."""
+    return HEADS.build(cfg)
+
+
+def build_loss(cfg):
+    """Build loss."""
+    return LOSSES.build(cfg)
+
+
+def build_segmentor(cfg, train_cfg=None, test_cfg=None):
+    """Build segmentor."""
+    if train_cfg is not None or test_cfg is not None:
+        warnings.warn(
+            'train_cfg and test_cfg is deprecated, '
+            'please specify them in model', UserWarning)
+    assert cfg.get('train_cfg') is None or train_cfg is None, \
+        'train_cfg specified in both outer field and model field '
+    assert cfg.get('test_cfg') is None or test_cfg is None, \
+        'test_cfg specified in both outer field and model field '
+    return SEGMENTORS.build(
+        cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
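# Illustrative sketch, not from the original files: the config-dict pattern that
# build_backbone/build_segmentor above rely on. A tiny stand-in registry keeps the
# snippet self-contained; TinyRegistry and ToyBackbone are hypothetical names, and
# the real lookup is done by mmcv's Registry / MODELS.build.
class TinyRegistry:
    def __init__(self):
        self._classes = {}

    def register_module(self):
        def _wrap(cls):
            self._classes[cls.__name__] = cls
            return cls
        return _wrap

    def build(self, cfg, default_args=None):
        cfg = dict(cfg, **(default_args or {}))   # copy, so the caller's cfg is untouched
        return self._classes[cfg.pop('type')](**cfg)

TOY_MODELS = TinyRegistry()

@TOY_MODELS.register_module()
class ToyBackbone:
    def __init__(self, depth, train_cfg=None, test_cfg=None):
        self.depth = depth

# 'type' selects the registered class; the remaining keys become kwargs, just as
# build_segmentor forwards train_cfg/test_cfg via default_args.
print(TOY_MODELS.build(dict(type='ToyBackbone', depth=50),
                       default_args=dict(train_cfg=None, test_cfg=None)).depth)  # 50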
diff --git a/annotator/uniformer/mmseg/models/decode_heads/__init__.py b/annotator/uniformer/mmseg/models/decode_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac66d3cfe0ea04af45c0f3594bf135841c3812e3
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/__init__.py
@@ -0,0 +1,28 @@
+from .ann_head import ANNHead
+from .apc_head import APCHead
+from .aspp_head import ASPPHead
+from .cc_head import CCHead
+from .da_head import DAHead
+from .dm_head import DMHead
+from .dnl_head import DNLHead
+from .ema_head import EMAHead
+from .enc_head import EncHead
+from .fcn_head import FCNHead
+from .fpn_head import FPNHead
+from .gc_head import GCHead
+from .lraspp_head import LRASPPHead
+from .nl_head import NLHead
+from .ocr_head import OCRHead
+# from .point_head import PointHead
+from .psa_head import PSAHead
+from .psp_head import PSPHead
+from .sep_aspp_head import DepthwiseSeparableASPPHead
+from .sep_fcn_head import DepthwiseSeparableFCNHead
+from .uper_head import UPerHead
+
+__all__ = [
+    'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
+    'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
+    'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
+    'APCHead', 'DMHead', 'LRASPPHead'
+]
diff --git a/annotator/uniformer/mmseg/models/decode_heads/ann_head.py b/annotator/uniformer/mmseg/models/decode_heads/ann_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..30aaacc2cafc568d3de71d1477b4de0dc0fea9d3
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/ann_head.py
@@ -0,0 +1,245 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from ..builder import HEADS
+from ..utils import SelfAttentionBlock as _SelfAttentionBlock
+from .decode_head import BaseDecodeHead
+
+
+class PPMConcat(nn.ModuleList):
+    """Pyramid Pooling Module that only concat the features of each layer.
+
+    Args:
+        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+            Module.
+    """
+
+    def __init__(self, pool_scales=(1, 3, 6, 8)):
+        super(PPMConcat, self).__init__(
+            [nn.AdaptiveAvgPool2d(pool_scale) for pool_scale in pool_scales])
+
+    def forward(self, feats):
+        """Forward function."""
+        ppm_outs = []
+        for ppm in self:
+            ppm_out = ppm(feats)
+            ppm_outs.append(ppm_out.view(*feats.shape[:2], -1))
+        concat_outs = torch.cat(ppm_outs, dim=2)
+        return concat_outs
+
+
+class SelfAttentionBlock(_SelfAttentionBlock):
+    """Make a ANN used SelfAttentionBlock.
+
+    Args:
+        low_in_channels (int): Input channels of lower level feature,
+            which is the key feature for self-attention.
+        high_in_channels (int): Input channels of higher level feature,
+            which is the query feature for self-attention.
+        channels (int): Output channels of key/query transform.
+        out_channels (int): Output channels.
+        share_key_query (bool): Whether to share the projection weights
+            between the key and query projections.
+        query_scale (int): The scale of query feature map.
+        key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+            Module of key feature.
+        conv_cfg (dict|None): Config of conv layers.
+        norm_cfg (dict|None): Config of norm layers.
+        act_cfg (dict|None): Config of activation layers.
+    """
+
+    def __init__(self, low_in_channels, high_in_channels, channels,
+                 out_channels, share_key_query, query_scale, key_pool_scales,
+                 conv_cfg, norm_cfg, act_cfg):
+        key_psp = PPMConcat(key_pool_scales)
+        if query_scale > 1:
+            query_downsample = nn.MaxPool2d(kernel_size=query_scale)
+        else:
+            query_downsample = None
+        super(SelfAttentionBlock, self).__init__(
+            key_in_channels=low_in_channels,
+            query_in_channels=high_in_channels,
+            channels=channels,
+            out_channels=out_channels,
+            share_key_query=share_key_query,
+            query_downsample=query_downsample,
+            key_downsample=key_psp,
+            key_query_num_convs=1,
+            key_query_norm=True,
+            value_out_num_convs=1,
+            value_out_norm=False,
+            matmul_norm=True,
+            with_out=True,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+
+
+class AFNB(nn.Module):
+    """Asymmetric Fusion Non-local Block(AFNB)
+
+    Args:
+        low_in_channels (int): Input channels of lower level feature,
+            which is the key feature for self-attention.
+        high_in_channels (int): Input channels of higher level feature,
+            which is the query feature for self-attention.
+        channels (int): Output channels of key/query transform.
+        out_channels (int): Output channels.
+        query_scales (tuple[int]): The scales of query feature map.
+            Default: (1,)
+        key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+            Module of key feature.
+        conv_cfg (dict|None): Config of conv layers.
+        norm_cfg (dict|None): Config of norm layers.
+        act_cfg (dict|None): Config of activation layers.
+    """
+
+    def __init__(self, low_in_channels, high_in_channels, channels,
+                 out_channels, query_scales, key_pool_scales, conv_cfg,
+                 norm_cfg, act_cfg):
+        super(AFNB, self).__init__()
+        self.stages = nn.ModuleList()
+        for query_scale in query_scales:
+            self.stages.append(
+                SelfAttentionBlock(
+                    low_in_channels=low_in_channels,
+                    high_in_channels=high_in_channels,
+                    channels=channels,
+                    out_channels=out_channels,
+                    share_key_query=False,
+                    query_scale=query_scale,
+                    key_pool_scales=key_pool_scales,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg))
+        self.bottleneck = ConvModule(
+            out_channels + high_in_channels,
+            out_channels,
+            1,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=None)
+
+    def forward(self, low_feats, high_feats):
+        """Forward function."""
+        priors = [stage(high_feats, low_feats) for stage in self.stages]
+        context = torch.stack(priors, dim=0).sum(dim=0)
+        output = self.bottleneck(torch.cat([context, high_feats], 1))
+        return output
+
+
+class APNB(nn.Module):
+    """Asymmetric Pyramid Non-local Block (APNB)
+
+    Args:
+        in_channels (int): Input channels of key/query feature,
+            which is the key feature for self-attention.
+        channels (int): Output channels of key/query transform.
+        out_channels (int): Output channels.
+        query_scales (tuple[int]): The scales of query feature map.
+            Default: (1,)
+        key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+            Module of key feature.
+        conv_cfg (dict|None): Config of conv layers.
+        norm_cfg (dict|None): Config of norm layers.
+        act_cfg (dict|None): Config of activation layers.
+    """
+
+    def __init__(self, in_channels, channels, out_channels, query_scales,
+                 key_pool_scales, conv_cfg, norm_cfg, act_cfg):
+        super(APNB, self).__init__()
+        self.stages = nn.ModuleList()
+        for query_scale in query_scales:
+            self.stages.append(
+                SelfAttentionBlock(
+                    low_in_channels=in_channels,
+                    high_in_channels=in_channels,
+                    channels=channels,
+                    out_channels=out_channels,
+                    share_key_query=True,
+                    query_scale=query_scale,
+                    key_pool_scales=key_pool_scales,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg))
+        self.bottleneck = ConvModule(
+            2 * in_channels,
+            out_channels,
+            1,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+
+    def forward(self, feats):
+        """Forward function."""
+        priors = [stage(feats, feats) for stage in self.stages]
+        context = torch.stack(priors, dim=0).sum(dim=0)
+        output = self.bottleneck(torch.cat([context, feats], 1))
+        return output
+
+
+@HEADS.register_module()
+class ANNHead(BaseDecodeHead):
+    """Asymmetric Non-local Neural Networks for Semantic Segmentation.
+
+    This head is the implementation of `ANNNet
+    <https://arxiv.org/abs/1908.07678>`_.
+
+    Args:
+        project_channels (int): Projection channels for Nonlocal.
+        query_scales (tuple[int]): The scales of query feature map.
+            Default: (1,)
+        key_pool_scales (tuple[int]): The pooling scales of key feature map.
+            Default: (1, 3, 6, 8).
+    """
+
+    def __init__(self,
+                 project_channels,
+                 query_scales=(1, ),
+                 key_pool_scales=(1, 3, 6, 8),
+                 **kwargs):
+        super(ANNHead, self).__init__(
+            input_transform='multiple_select', **kwargs)
+        assert len(self.in_channels) == 2
+        low_in_channels, high_in_channels = self.in_channels
+        self.project_channels = project_channels
+        self.fusion = AFNB(
+            low_in_channels=low_in_channels,
+            high_in_channels=high_in_channels,
+            out_channels=high_in_channels,
+            channels=project_channels,
+            query_scales=query_scales,
+            key_pool_scales=key_pool_scales,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        self.bottleneck = ConvModule(
+            high_in_channels,
+            self.channels,
+            3,
+            padding=1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        self.context = APNB(
+            in_channels=self.channels,
+            out_channels=self.channels,
+            channels=project_channels,
+            query_scales=query_scales,
+            key_pool_scales=key_pool_scales,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+
+    def forward(self, inputs):
+        """Forward function."""
+        low_feats, high_feats = self._transform_inputs(inputs)
+        output = self.fusion(low_feats, high_feats)
+        output = self.dropout(output)
+        output = self.bottleneck(output)
+        output = self.context(output)
+        output = self.cls_seg(output)
+
+        return output
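# Illustrative sketch, not from the original files: the shape bookkeeping behind
# PPMConcat above. The key feature map is pooled at several scales and the
# flattened results are concatenated, which keeps the ANN attention "asymmetric":
# queries stay at full resolution while keys shrink to 1 + 9 + 36 + 64 positions.
# The tensor sizes below are arbitrary stand-ins.
import torch
import torch.nn.functional as F

feats = torch.randn(2, 256, 64, 64)                  # [B, C, H, W] low-level (key) feature
pooled = [F.adaptive_avg_pool2d(feats, s).view(2, 256, -1)
          for s in (1, 3, 6, 8)]                     # default key_pool_scales
keys = torch.cat(pooled, dim=2)
print(keys.shape)                                    # torch.Size([2, 256, 110])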
diff --git a/annotator/uniformer/mmseg/models/decode_heads/apc_head.py b/annotator/uniformer/mmseg/models/decode_heads/apc_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7038bdbe0edf2a1f184b6899486d2d190dda076
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/apc_head.py
@@ -0,0 +1,158 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+class ACM(nn.Module):
+    """Adaptive Context Module used in APCNet.
+
+    Args:
+        pool_scale (int): Pooling scale used in Adaptive Context
+            Module to extract region features.
+        fusion (bool): Add one conv to fuse residual feature.
+        in_channels (int): Input channels.
+        channels (int): Channels after modules, before conv_seg.
+        conv_cfg (dict | None): Config of conv layers.
+        norm_cfg (dict | None): Config of norm layers.
+        act_cfg (dict): Config of activation layers.
+    """
+
+    def __init__(self, pool_scale, fusion, in_channels, channels, conv_cfg,
+                 norm_cfg, act_cfg):
+        super(ACM, self).__init__()
+        self.pool_scale = pool_scale
+        self.fusion = fusion
+        self.in_channels = in_channels
+        self.channels = channels
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        self.pooled_redu_conv = ConvModule(
+            self.in_channels,
+            self.channels,
+            1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+
+        self.input_redu_conv = ConvModule(
+            self.in_channels,
+            self.channels,
+            1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+
+        self.global_info = ConvModule(
+            self.channels,
+            self.channels,
+            1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+
+        self.gla = nn.Conv2d(self.channels, self.pool_scale**2, 1, 1, 0)
+
+        self.residual_conv = ConvModule(
+            self.channels,
+            self.channels,
+            1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+
+        if self.fusion:
+            self.fusion_conv = ConvModule(
+                self.channels,
+                self.channels,
+                1,
+                conv_cfg=self.conv_cfg,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg)
+
+    def forward(self, x):
+        """Forward function."""
+        pooled_x = F.adaptive_avg_pool2d(x, self.pool_scale)
+        # [batch_size, channels, h, w]
+        x = self.input_redu_conv(x)
+        # [batch_size, channels, pool_scale, pool_scale]
+        pooled_x = self.pooled_redu_conv(pooled_x)
+        batch_size = x.size(0)
+        # [batch_size, pool_scale * pool_scale, channels]
+        pooled_x = pooled_x.view(batch_size, self.channels,
+                                 -1).permute(0, 2, 1).contiguous()
+        # [batch_size, h * w, pool_scale * pool_scale]
+        affinity_matrix = self.gla(x + resize(
+            self.global_info(F.adaptive_avg_pool2d(x, 1)), size=x.shape[2:])
+                                   ).permute(0, 2, 3, 1).reshape(
+                                       batch_size, -1, self.pool_scale**2)
+        affinity_matrix = F.sigmoid(affinity_matrix)
+        # [batch_size, h * w, channels]
+        z_out = torch.matmul(affinity_matrix, pooled_x)
+        # [batch_size, channels, h * w]
+        z_out = z_out.permute(0, 2, 1).contiguous()
+        # [batch_size, channels, h, w]
+        z_out = z_out.view(batch_size, self.channels, x.size(2), x.size(3))
+        z_out = self.residual_conv(z_out)
+        z_out = F.relu(z_out + x)
+        if self.fusion:
+            z_out = self.fusion_conv(z_out)
+
+        return z_out
+
+
+@HEADS.register_module()
+class APCHead(BaseDecodeHead):
+    """Adaptive Pyramid Context Network for Semantic Segmentation.
+
+    This head is the implementation of
+    `APCNet <https://openaccess.thecvf.com/content_CVPR_2019/papers/\
+    He_Adaptive_Pyramid_Context_Network_for_Semantic_Segmentation_\
+    CVPR_2019_paper.pdf>`_.
+
+    Args:
+        pool_scales (tuple[int]): Pooling scales used in Adaptive Context
+            Module. Default: (1, 2, 3, 6).
+        fusion (bool): Add one conv to fuse residual feature.
+    """
+
+    def __init__(self, pool_scales=(1, 2, 3, 6), fusion=True, **kwargs):
+        super(APCHead, self).__init__(**kwargs)
+        assert isinstance(pool_scales, (list, tuple))
+        self.pool_scales = pool_scales
+        self.fusion = fusion
+        acm_modules = []
+        for pool_scale in self.pool_scales:
+            acm_modules.append(
+                ACM(pool_scale,
+                    self.fusion,
+                    self.in_channels,
+                    self.channels,
+                    conv_cfg=self.conv_cfg,
+                    norm_cfg=self.norm_cfg,
+                    act_cfg=self.act_cfg))
+        self.acm_modules = nn.ModuleList(acm_modules)
+        self.bottleneck = ConvModule(
+            self.in_channels + len(pool_scales) * self.channels,
+            self.channels,
+            3,
+            padding=1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+
+    def forward(self, inputs):
+        """Forward function."""
+        x = self._transform_inputs(inputs)
+        acm_outs = [x]
+        for acm_module in self.acm_modules:
+            acm_outs.append(acm_module(x))
+        acm_outs = torch.cat(acm_outs, dim=1)
+        output = self.bottleneck(acm_outs)
+        output = self.cls_seg(output)
+        return output
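# Illustrative sketch, not from the original files: the core matmul inside
# ACM.forward above. An affinity map between every pixel and pool_scale**2 pooled
# regions redistributes region features back onto the pixel grid. Random tensors
# stand in for the conv outputs; the sizes are arbitrary.
import torch

b, c, h, w, s = 2, 512, 32, 32, 3
pooled_x = torch.randn(b, s * s, c)                  # pooled region features
affinity = torch.sigmoid(torch.randn(b, h * w, s * s))
z = torch.matmul(affinity, pooled_x)                 # [B, H*W, C]
z = z.permute(0, 2, 1).reshape(b, c, h, w)           # back to a feature map
print(z.shape)                                       # torch.Size([2, 512, 32, 32])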
diff --git a/annotator/uniformer/mmseg/models/decode_heads/aspp_head.py b/annotator/uniformer/mmseg/models/decode_heads/aspp_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa914b5bb25124d1ff199553d96713d6a80484c0
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/aspp_head.py
@@ -0,0 +1,107 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+class ASPPModule(nn.ModuleList):
+    """Atrous Spatial Pyramid Pooling (ASPP) Module.
+
+    Args:
+        dilations (tuple[int]): Dilation rate of each layer.
+        in_channels (int): Input channels.
+        channels (int): Channels after modules, before conv_seg.
+        conv_cfg (dict|None): Config of conv layers.
+        norm_cfg (dict|None): Config of norm layers.
+        act_cfg (dict): Config of activation layers.
+    """
+
+    def __init__(self, dilations, in_channels, channels, conv_cfg, norm_cfg,
+                 act_cfg):
+        super(ASPPModule, self).__init__()
+        self.dilations = dilations
+        self.in_channels = in_channels
+        self.channels = channels
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        for dilation in dilations:
+            self.append(
+                ConvModule(
+                    self.in_channels,
+                    self.channels,
+                    1 if dilation == 1 else 3,
+                    dilation=dilation,
+                    padding=0 if dilation == 1 else dilation,
+                    conv_cfg=self.conv_cfg,
+                    norm_cfg=self.norm_cfg,
+                    act_cfg=self.act_cfg))
+
+    def forward(self, x):
+        """Forward function."""
+        aspp_outs = []
+        for aspp_module in self:
+            aspp_outs.append(aspp_module(x))
+
+        return aspp_outs
+
+
+@HEADS.register_module()
+class ASPPHead(BaseDecodeHead):
+    """Rethinking Atrous Convolution for Semantic Image Segmentation.
+
+    This head is the implementation of `DeepLabV3
+    <https://arxiv.org/abs/1706.05587>`_.
+
+    Args:
+        dilations (tuple[int]): Dilation rates for ASPP module.
+            Default: (1, 6, 12, 18).
+    """
+
+    def __init__(self, dilations=(1, 6, 12, 18), **kwargs):
+        super(ASPPHead, self).__init__(**kwargs)
+        assert isinstance(dilations, (list, tuple))
+        self.dilations = dilations
+        self.image_pool = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            ConvModule(
+                self.in_channels,
+                self.channels,
+                1,
+                conv_cfg=self.conv_cfg,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg))
+        self.aspp_modules = ASPPModule(
+            dilations,
+            self.in_channels,
+            self.channels,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        self.bottleneck = ConvModule(
+            (len(dilations) + 1) * self.channels,
+            self.channels,
+            3,
+            padding=1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+
+    def forward(self, inputs):
+        """Forward function."""
+        x = self._transform_inputs(inputs)
+        aspp_outs = [
+            resize(
+                self.image_pool(x),
+                size=x.size()[2:],
+                mode='bilinear',
+                align_corners=self.align_corners)
+        ]
+        aspp_outs.extend(self.aspp_modules(x))
+        aspp_outs = torch.cat(aspp_outs, dim=1)
+        output = self.bottleneck(aspp_outs)
+        output = self.cls_seg(output)
+        return output
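# Illustrative sketch, not from the original files: why ASPPModule above pairs
# padding with dilation for its 3x3 branches. With padding == dilation every
# branch preserves the spatial size, so the branch outputs (plus the image-pool
# branch) can be concatenated channel-wise. Channel counts here are reduced
# stand-ins for the real backbone feature.
import torch
import torch.nn as nn

x = torch.randn(1, 256, 33, 33)
for d in (1, 6, 12, 18):                             # default dilations of ASPPHead
    conv = nn.Conv2d(256, 64, 1 if d == 1 else 3,
                     dilation=d, padding=0 if d == 1 else d)
    print(d, tuple(conv(x).shape))                   # spatial size stays 33 x 33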
diff --git a/annotator/uniformer/mmseg/models/decode_heads/cascade_decode_head.py b/annotator/uniformer/mmseg/models/decode_heads/cascade_decode_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..d02122ca0e68743b1bf7a893afae96042f23838c
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/cascade_decode_head.py
@@ -0,0 +1,57 @@
+from abc import ABCMeta, abstractmethod
+
+from .decode_head import BaseDecodeHead
+
+
+class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
+    """Base class for cascade decode head used in
+    :class:`CascadeEncoderDecoder`."""
+
+    def __init__(self, *args, **kwargs):
+        super(BaseCascadeDecodeHead, self).__init__(*args, **kwargs)
+
+    @abstractmethod
+    def forward(self, inputs, prev_output):
+        """Placeholder of forward function."""
+        pass
+
+    def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg,
+                      train_cfg):
+        """Forward function for training.
+        Args:
+            inputs (list[Tensor]): List of multi-level img features.
+            prev_output (Tensor): The output of previous decode head.
+            img_metas (list[dict]): List of image info dict where each dict
+                has: 'img_shape', 'scale_factor', 'flip', and may also contain
+                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                For details on the values of these keys see
+                `mmseg/datasets/pipelines/formatting.py:Collect`.
+            gt_semantic_seg (Tensor): Semantic segmentation masks,
+                used if the architecture supports the semantic segmentation task.
+            train_cfg (dict): The training config.
+
+        Returns:
+            dict[str, Tensor]: a dictionary of loss components
+        """
+        seg_logits = self.forward(inputs, prev_output)
+        losses = self.losses(seg_logits, gt_semantic_seg)
+
+        return losses
+
+    def forward_test(self, inputs, prev_output, img_metas, test_cfg):
+        """Forward function for testing.
+
+        Args:
+            inputs (list[Tensor]): List of multi-level img features.
+            prev_output (Tensor): The output of previous decode head.
+            img_metas (list[dict]): List of image info dict where each dict
+                has: 'img_shape', 'scale_factor', 'flip', and may also contain
+                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                For details on the values of these keys see
+                `mmseg/datasets/pipelines/formatting.py:Collect`.
+            test_cfg (dict): The testing config.
+
+        Returns:
+            Tensor: Output segmentation map.
+        """
+        return self.forward(inputs, prev_output)
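# Illustrative sketch, not from the original files: how a cascade is expected to
# be wired -- a later head built on BaseCascadeDecodeHead receives both the
# backbone features and the previous head's prediction. The two toy functions
# below are placeholders, not real decode heads.
import torch

def first_head(inputs):                              # plays the role of a BaseDecodeHead
    return inputs[0].mean(dim=1, keepdim=True)       # coarse "logits"

def refine_head(inputs, prev_output):                # plays the role of a cascade head
    return inputs[0][:, :1] + prev_output            # refines using prev_output

feats = [torch.randn(1, 256, 64, 64)]                # multi-level features (one level here)
prev = first_head(feats)
out = refine_head(feats, prev)                       # mirrors forward(inputs, prev_output)
print(prev.shape, out.shape)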
diff --git a/annotator/uniformer/mmseg/models/decode_heads/cc_head.py b/annotator/uniformer/mmseg/models/decode_heads/cc_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b9abb4e747f92657f4220b29788539340986c00
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/cc_head.py
@@ -0,0 +1,42 @@
+import torch
+
+from ..builder import HEADS
+from .fcn_head import FCNHead
+
+try:
+    from annotator.uniformer.mmcv.ops import CrissCrossAttention
+except ModuleNotFoundError:
+    CrissCrossAttention = None
+
+
+@HEADS.register_module()
+class CCHead(FCNHead):
+    """CCNet: Criss-Cross Attention for Semantic Segmentation.
+
+    This head is the implementation of `CCNet
+    <https://arxiv.org/abs/1811.11721>`_.
+
+    Args:
+        recurrence (int): Number of recurrence of Criss Cross Attention
+            module. Default: 2.
+    """
+
+    def __init__(self, recurrence=2, **kwargs):
+        if CrissCrossAttention is None:
+            raise RuntimeError('Please install mmcv-full for '
+                               'CrissCrossAttention ops')
+        super(CCHead, self).__init__(num_convs=2, **kwargs)
+        self.recurrence = recurrence
+        self.cca = CrissCrossAttention(self.channels)
+
+    def forward(self, inputs):
+        """Forward function."""
+        x = self._transform_inputs(inputs)
+        output = self.convs[0](x)
+        for _ in range(self.recurrence):
+            output = self.cca(output)
+        output = self.convs[1](output)
+        if self.concat_input:
+            output = self.conv_cat(torch.cat([x, output], dim=1))
+        output = self.cls_seg(output)
+        return output
diff --git a/annotator/uniformer/mmseg/models/decode_heads/da_head.py b/annotator/uniformer/mmseg/models/decode_heads/da_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cd49fcfdc7c0a70f9485cc71843dcf3e0cb1774
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/da_head.py
@@ -0,0 +1,178 @@
+import torch
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule, Scale
+from torch import nn
+
+from annotator.uniformer.mmseg.core import add_prefix
+from ..builder import HEADS
+from ..utils import SelfAttentionBlock as _SelfAttentionBlock
+from .decode_head import BaseDecodeHead
+
+
+class PAM(_SelfAttentionBlock):
+    """Position Attention Module (PAM)
+
+    Args:
+        in_channels (int): Input channels of key/query feature.
+        channels (int): Output channels of key/query transform.
+    """
+
+    def __init__(self, in_channels, channels):
+        super(PAM, self).__init__(
+            key_in_channels=in_channels,
+            query_in_channels=in_channels,
+            channels=channels,
+            out_channels=in_channels,
+            share_key_query=False,
+            query_downsample=None,
+            key_downsample=None,
+            key_query_num_convs=1,
+            key_query_norm=False,
+            value_out_num_convs=1,
+            value_out_norm=False,
+            matmul_norm=False,
+            with_out=False,
+            conv_cfg=None,
+            norm_cfg=None,
+            act_cfg=None)
+
+        self.gamma = Scale(0)
+
+    def forward(self, x):
+        """Forward function."""
+        out = super(PAM, self).forward(x, x)
+
+        out = self.gamma(out) + x
+        return out
+
+
+class CAM(nn.Module):
+    """Channel Attention Module (CAM)"""
+
+    def __init__(self):
+        super(CAM, self).__init__()
+        self.gamma = Scale(0)
+
+    def forward(self, x):
+        """Forward function."""
+        batch_size, channels, height, width = x.size()
+        proj_query = x.view(batch_size, channels, -1)
+        proj_key = x.view(batch_size, channels, -1).permute(0, 2, 1)
+        energy = torch.bmm(proj_query, proj_key)
+        energy_new = torch.max(
+            energy, -1, keepdim=True)[0].expand_as(energy) - energy
+        attention = F.softmax(energy_new, dim=-1)
+        proj_value = x.view(batch_size, channels, -1)
+
+        out = torch.bmm(attention, proj_value)
+        out = out.view(batch_size, channels, height, width)
+
+        out = self.gamma(out) + x
+        return out
+
+
+@HEADS.register_module()
+class DAHead(BaseDecodeHead):
+    """Dual Attention Network for Scene Segmentation.
+
+    This head is the implementation of `DANet
+    <https://arxiv.org/abs/1809.02983>`_.
+
+    Args:
+        pam_channels (int): The channels of Position Attention Module(PAM).
+    """
+
+    def __init__(self, pam_channels, **kwargs):
+        super(DAHead, self).__init__(**kwargs)
+        self.pam_channels = pam_channels
+        self.pam_in_conv = ConvModule(
+            self.in_channels,
+            self.channels,
+            3,
+            padding=1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        self.pam = PAM(self.channels, pam_channels)
+        self.pam_out_conv = ConvModule(
+            self.channels,
+            self.channels,
+            3,
+            padding=1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        self.pam_conv_seg = nn.Conv2d(
+            self.channels, self.num_classes, kernel_size=1)
+
+        self.cam_in_conv = ConvModule(
+            self.in_channels,
+            self.channels,
+            3,
+            padding=1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        self.cam = CAM()
+        self.cam_out_conv = ConvModule(
+            self.channels,
+            self.channels,
+            3,
+            padding=1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        self.cam_conv_seg = nn.Conv2d(
+            self.channels, self.num_classes, kernel_size=1)
+
+    def pam_cls_seg(self, feat):
+        """PAM feature classification."""
+        if self.dropout is not None:
+            feat = self.dropout(feat)
+        output = self.pam_conv_seg(feat)
+        return output
+
+    def cam_cls_seg(self, feat):
+        """CAM feature classification."""
+        if self.dropout is not None:
+            feat = self.dropout(feat)
+        output = self.cam_conv_seg(feat)
+        return output
+
+    def forward(self, inputs):
+        """Forward function."""
+        x = self._transform_inputs(inputs)
+        pam_feat = self.pam_in_conv(x)
+        pam_feat = self.pam(pam_feat)
+        pam_feat = self.pam_out_conv(pam_feat)
+        pam_out = self.pam_cls_seg(pam_feat)
+
+        cam_feat = self.cam_in_conv(x)
+        cam_feat = self.cam(cam_feat)
+        cam_feat = self.cam_out_conv(cam_feat)
+        cam_out = self.cam_cls_seg(cam_feat)
+
+        feat_sum = pam_feat + cam_feat
+        pam_cam_out = self.cls_seg(feat_sum)
+
+        return pam_cam_out, pam_out, cam_out
+
+    def forward_test(self, inputs, img_metas, test_cfg):
+        """Forward function for testing, only ``pam_cam`` is used."""
+        return self.forward(inputs)[0]
+
+    def losses(self, seg_logit, seg_label):
+        """Compute ``pam_cam``, ``pam``, ``cam`` loss."""
+        pam_cam_seg_logit, pam_seg_logit, cam_seg_logit = seg_logit
+        loss = dict()
+        loss.update(
+            add_prefix(
+                super(DAHead, self).losses(pam_cam_seg_logit, seg_label),
+                'pam_cam'))
+        loss.update(
+            add_prefix(
+                super(DAHead, self).losses(pam_seg_logit, seg_label), 'pam'))
+        loss.update(
+            add_prefix(
+                super(DAHead, self).losses(cam_seg_logit, seg_label), 'cam'))
+        return loss
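# Illustrative sketch, not from the original files: the channel-attention
# arithmetic inside CAM.forward above, written out on a random tensor so the
# bmm shapes are visible. The residual Scale(0) gate is omitted here.
import torch
import torch.nn.functional as F

x = torch.randn(2, 64, 16, 16)
b, c, h, w = x.shape
q = x.view(b, c, -1)                                 # [B, C, HW]
k = x.view(b, c, -1).permute(0, 2, 1)                # [B, HW, C]
energy = torch.bmm(q, k)                             # [B, C, C] channel affinities
energy = energy.max(-1, keepdim=True)[0].expand_as(energy) - energy
attn = F.softmax(energy, dim=-1)
out = torch.bmm(attn, x.view(b, c, -1)).view(b, c, h, w)
print(out.shape)                                     # torch.Size([2, 64, 16, 16])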
diff --git a/annotator/uniformer/mmseg/models/decode_heads/decode_head.py b/annotator/uniformer/mmseg/models/decode_heads/decode_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..88a661b8f6fec5d4c031d3d85e80777ee63951a6
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/decode_head.py
@@ -0,0 +1,234 @@
+from abc import ABCMeta, abstractmethod
+
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import normal_init
+from annotator.uniformer.mmcv.runner import auto_fp16, force_fp32
+
+from annotator.uniformer.mmseg.core import build_pixel_sampler
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import build_loss
+from ..losses import accuracy
+
+
+class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
+    """Base class for BaseDecodeHead.
+
+    Args:
+        in_channels (int|Sequence[int]): Input channels.
+        channels (int): Channels after modules, before conv_seg.
+        num_classes (int): Number of classes.
+        dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
+        conv_cfg (dict|None): Config of conv layers. Default: None.
+        norm_cfg (dict|None): Config of norm layers. Default: None.
+        act_cfg (dict): Config of activation layers.
+            Default: dict(type='ReLU')
+        in_index (int|Sequence[int]): Input feature index. Default: -1
+        input_transform (str|None): Transformation type of input features.
+            Options: 'resize_concat', 'multiple_select', None.
+            'resize_concat': Multiple feature maps will be resized to the
+                same size as the first one and then concatenated together.
+                Usually used in the FCN head of HRNet.
+            'multiple_select': Multiple feature maps will be bundled into
+                a list and passed into the decode head.
+            None: Only one select feature map is allowed.
+            Default: None.
+        loss_decode (dict): Config of decode loss.
+            Default: dict(type='CrossEntropyLoss').
+        ignore_index (int | None): The label index to be ignored. When using
+            masked BCE loss, ignore_index should be set to None. Default: 255
+        sampler (dict|None): The config of segmentation map sampler.
+            Default: None.
+        align_corners (bool): align_corners argument of F.interpolate.
+            Default: False.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 channels,
+                 *,
+                 num_classes,
+                 dropout_ratio=0.1,
+                 conv_cfg=None,
+                 norm_cfg=None,
+                 act_cfg=dict(type='ReLU'),
+                 in_index=-1,
+                 input_transform=None,
+                 loss_decode=dict(
+                     type='CrossEntropyLoss',
+                     use_sigmoid=False,
+                     loss_weight=1.0),
+                 ignore_index=255,
+                 sampler=None,
+                 align_corners=False):
+        super(BaseDecodeHead, self).__init__()
+        self._init_inputs(in_channels, in_index, input_transform)
+        self.channels = channels
+        self.num_classes = num_classes
+        self.dropout_ratio = dropout_ratio
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        self.in_index = in_index
+        self.loss_decode = build_loss(loss_decode)
+        self.ignore_index = ignore_index
+        self.align_corners = align_corners
+        if sampler is not None:
+            self.sampler = build_pixel_sampler(sampler, context=self)
+        else:
+            self.sampler = None
+
+        self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
+        if dropout_ratio > 0:
+            self.dropout = nn.Dropout2d(dropout_ratio)
+        else:
+            self.dropout = None
+        self.fp16_enabled = False
+
+    def extra_repr(self):
+        """Extra repr."""
+        s = f'input_transform={self.input_transform}, ' \
+            f'ignore_index={self.ignore_index}, ' \
+            f'align_corners={self.align_corners}'
+        return s
+
+    def _init_inputs(self, in_channels, in_index, input_transform):
+        """Check and initialize input transforms.
+
+        The in_channels, in_index and input_transform must match.
+        Specifically, when input_transform is None, only a single feature map
+        will be selected, so in_channels and in_index must be of type int.
+        When input_transform is not None, in_channels and in_index must be
+        lists or tuples of the same length.
+
+        Args:
+            in_channels (int|Sequence[int]): Input channels.
+            in_index (int|Sequence[int]): Input feature index.
+            input_transform (str|None): Transformation type of input features.
+                Options: 'resize_concat', 'multiple_select', None.
+                'resize_concat': Multiple feature maps will be resized to the
+                    same size as the first one and then concatenated together.
+                    Usually used in the FCN head of HRNet.
+                'multiple_select': Multiple feature maps will be bundled into
+                    a list and passed into the decode head.
+                None: Only one select feature map is allowed.
+        """
+
+        if input_transform is not None:
+            assert input_transform in ['resize_concat', 'multiple_select']
+        self.input_transform = input_transform
+        self.in_index = in_index
+        if input_transform is not None:
+            assert isinstance(in_channels, (list, tuple))
+            assert isinstance(in_index, (list, tuple))
+            assert len(in_channels) == len(in_index)
+            if input_transform == 'resize_concat':
+                self.in_channels = sum(in_channels)
+            else:
+                self.in_channels = in_channels
+        else:
+            assert isinstance(in_channels, int)
+            assert isinstance(in_index, int)
+            self.in_channels = in_channels
+
+    def init_weights(self):
+        """Initialize weights of classification layer."""
+        normal_init(self.conv_seg, mean=0, std=0.01)
+
+    def _transform_inputs(self, inputs):
+        """Transform inputs for decoder.
+
+        Args:
+            inputs (list[Tensor]): List of multi-level img features.
+
+        Returns:
+            Tensor: The transformed inputs
+        """
+
+        if self.input_transform == 'resize_concat':
+            inputs = [inputs[i] for i in self.in_index]
+            upsampled_inputs = [
+                resize(
+                    input=x,
+                    size=inputs[0].shape[2:],
+                    mode='bilinear',
+                    align_corners=self.align_corners) for x in inputs
+            ]
+            inputs = torch.cat(upsampled_inputs, dim=1)
+        elif self.input_transform == 'multiple_select':
+            inputs = [inputs[i] for i in self.in_index]
+        else:
+            inputs = inputs[self.in_index]
+
+        return inputs
+
+    @auto_fp16()
+    @abstractmethod
+    def forward(self, inputs):
+        """Placeholder of forward function."""
+        pass
+
+    def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
+        """Forward function for training.
+        Args:
+            inputs (list[Tensor]): List of multi-level img features.
+            img_metas (list[dict]): List of image info dict where each dict
+                has: 'img_shape', 'scale_factor', 'flip', and may also contain
+                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                For details on the values of these keys see
+                `mmseg/datasets/pipelines/formatting.py:Collect`.
+            gt_semantic_seg (Tensor): Semantic segmentation masks,
+                used if the architecture supports the semantic segmentation task.
+            train_cfg (dict): The training config.
+
+        Returns:
+            dict[str, Tensor]: a dictionary of loss components
+        """
+        seg_logits = self.forward(inputs)
+        losses = self.losses(seg_logits, gt_semantic_seg)
+        return losses
+
+    def forward_test(self, inputs, img_metas, test_cfg):
+        """Forward function for testing.
+
+        Args:
+            inputs (list[Tensor]): List of multi-level img features.
+            img_metas (list[dict]): List of image info dict where each dict
+                has: 'img_shape', 'scale_factor', 'flip', and may also contain
+                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                For details on the values of these keys see
+                `mmseg/datasets/pipelines/formatting.py:Collect`.
+            test_cfg (dict): The testing config.
+
+        Returns:
+            Tensor: Output segmentation map.
+        """
+        return self.forward(inputs)
+
+    def cls_seg(self, feat):
+        """Classify each pixel."""
+        if self.dropout is not None:
+            feat = self.dropout(feat)
+        output = self.conv_seg(feat)
+        return output
+
+    @force_fp32(apply_to=('seg_logit', ))
+    def losses(self, seg_logit, seg_label):
+        """Compute segmentation loss."""
+        loss = dict()
+        seg_logit = resize(
+            input=seg_logit,
+            size=seg_label.shape[2:],
+            mode='bilinear',
+            align_corners=self.align_corners)
+        if self.sampler is not None:
+            seg_weight = self.sampler.sample(seg_logit, seg_label)
+        else:
+            seg_weight = None
+        seg_label = seg_label.squeeze(1)
+        loss['loss_seg'] = self.loss_decode(
+            seg_logit,
+            seg_label,
+            weight=seg_weight,
+            ignore_index=self.ignore_index)
+        loss['acc_seg'] = accuracy(seg_logit, seg_label)
+        return loss
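# Illustrative sketch, not from the original files: what the three
# input_transform modes of BaseDecodeHead._transform_inputs do to a list of
# multi-level features, using plain F.interpolate in place of mmseg's resize
# wrapper. The feature shapes are arbitrary stand-ins.
import torch
import torch.nn.functional as F

feats = [torch.randn(1, 48, 64, 64), torch.randn(1, 96, 32, 32)]

single = feats[-1]                                   # input_transform=None, in_index=-1
selected = [feats[i] for i in (0, 1)]                # 'multiple_select': keep a list

up = [F.interpolate(f, size=feats[0].shape[2:], mode='bilinear',
                    align_corners=False) for f in feats]
concat = torch.cat(up, dim=1)                        # 'resize_concat': 48 + 96 channels
print(single.shape, len(selected), concat.shape)     # ..., 2, torch.Size([1, 144, 64, 64])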
diff --git a/annotator/uniformer/mmseg/models/decode_heads/dm_head.py b/annotator/uniformer/mmseg/models/decode_heads/dm_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..19c963923126b53ce22f60813540a35badf24b3d
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/dm_head.py
@@ -0,0 +1,140 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
+
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+class DCM(nn.Module):
+    """Dynamic Convolutional Module used in DMNet.
+
+    Args:
+        filter_size (int): The filter size of generated convolution kernel
+            used in Dynamic Convolutional Module.
+        fusion (bool): Add one conv to fuse DCM output feature.
+        in_channels (int): Input channels.
+        channels (int): Channels after modules, before conv_seg.
+        conv_cfg (dict | None): Config of conv layers.
+        norm_cfg (dict | None): Config of norm layers.
+        act_cfg (dict): Config of activation layers.
+    """
+
+    def __init__(self, filter_size, fusion, in_channels, channels, conv_cfg,
+                 norm_cfg, act_cfg):
+        super(DCM, self).__init__()
+        self.filter_size = filter_size
+        self.fusion = fusion
+        self.in_channels = in_channels
+        self.channels = channels
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        self.filter_gen_conv = nn.Conv2d(self.in_channels, self.channels, 1, 1,
+                                         0)
+
+        self.input_redu_conv = ConvModule(
+            self.in_channels,
+            self.channels,
+            1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+
+        if self.norm_cfg is not None:
+            self.norm = build_norm_layer(self.norm_cfg, self.channels)[1]
+        else:
+            self.norm = None
+        self.activate = build_activation_layer(self.act_cfg)
+
+        if self.fusion:
+            self.fusion_conv = ConvModule(
+                self.channels,
+                self.channels,
+                1,
+                conv_cfg=self.conv_cfg,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg)
+
+    def forward(self, x):
+        """Forward function."""
+        generated_filter = self.filter_gen_conv(
+            F.adaptive_avg_pool2d(x, self.filter_size))
+        x = self.input_redu_conv(x)
+        b, c, h, w = x.shape
+        # [1, b * c, h, w], c = self.channels
+        x = x.view(1, b * c, h, w)
+        # [b * c, 1, filter_size, filter_size]
+        generated_filter = generated_filter.view(b * c, 1, self.filter_size,
+                                                 self.filter_size)
+        pad = (self.filter_size - 1) // 2
+        if (self.filter_size - 1) % 2 == 0:
+            p2d = (pad, pad, pad, pad)
+        else:
+            p2d = (pad + 1, pad, pad + 1, pad)
+        x = F.pad(input=x, pad=p2d, mode='constant', value=0)
+        # [1, b * c, h, w]
+        output = F.conv2d(input=x, weight=generated_filter, groups=b * c)
+        # [b, c, h, w]
+        output = output.view(b, c, h, w)
+        if self.norm is not None:
+            output = self.norm(output)
+        output = self.activate(output)
+
+        if self.fusion:
+            output = self.fusion_conv(output)
+
+        return output
+
+
+@HEADS.register_module()
+class DMHead(BaseDecodeHead):
+    """Dynamic Multi-scale Filters for Semantic Segmentation.
+
+    This head is the implementation of
+    `DMNet <https://openaccess.thecvf.com/content_ICCV_2019/papers/\
+        He_Dynamic_Multi-Scale_Filters_for_Semantic_Segmentation_\
+            ICCV_2019_paper.pdf>`_.
+
+    Args:
+        filter_sizes (tuple[int]): The size of generated convolutional filters
+            used in Dynamic Convolutional Module. Default: (1, 3, 5, 7).
+        fusion (bool): Add one conv to fuse DCM output feature.
+    """
+
+    def __init__(self, filter_sizes=(1, 3, 5, 7), fusion=False, **kwargs):
+        super(DMHead, self).__init__(**kwargs)
+        assert isinstance(filter_sizes, (list, tuple))
+        self.filter_sizes = filter_sizes
+        self.fusion = fusion
+        dcm_modules = []
+        for filter_size in self.filter_sizes:
+            dcm_modules.append(
+                DCM(filter_size,
+                    self.fusion,
+                    self.in_channels,
+                    self.channels,
+                    conv_cfg=self.conv_cfg,
+                    norm_cfg=self.norm_cfg,
+                    act_cfg=self.act_cfg))
+        self.dcm_modules = nn.ModuleList(dcm_modules)
+        self.bottleneck = ConvModule(
+            self.in_channels + len(filter_sizes) * self.channels,
+            self.channels,
+            3,
+            padding=1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+
+    def forward(self, inputs):
+        """Forward function."""
+        x = self._transform_inputs(inputs)
+        dcm_outs = [x]
+        for dcm_module in self.dcm_modules:
+            dcm_outs.append(dcm_module(x))
+        dcm_outs = torch.cat(dcm_outs, dim=1)
+        output = self.bottleneck(dcm_outs)
+        output = self.cls_seg(output)
+        return output
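# Illustrative sketch, not from the original files: the grouped-convolution trick
# DCM.forward above uses to apply a different generated filter to every
# (sample, channel) pair in a single F.conv2d call. Sizes are small stand-ins.
import torch
import torch.nn.functional as F

b, c, h, w, k = 2, 8, 16, 16, 3
x = torch.randn(b, c, h, w).view(1, b * c, h, w)     # fold the batch into channels
filters = torch.randn(b * c, 1, k, k)                # one k x k filter per (sample, channel)
pad = (k - 1) // 2
out = F.conv2d(F.pad(x, (pad, pad, pad, pad)), filters, groups=b * c)
print(out.view(b, c, h, w).shape)                    # torch.Size([2, 8, 16, 16])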
diff --git a/annotator/uniformer/mmseg/models/decode_heads/dnl_head.py b/annotator/uniformer/mmseg/models/decode_heads/dnl_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..333280c5947066fd3c7ebcfe302a0e7ad65480d5
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/dnl_head.py
@@ -0,0 +1,131 @@
+import torch
+from annotator.uniformer.mmcv.cnn import NonLocal2d
+from torch import nn
+
+from ..builder import HEADS
+from .fcn_head import FCNHead
+
+
+class DisentangledNonLocal2d(NonLocal2d):
+    """Disentangled Non-Local Blocks.
+
+    Args:
+        temperature (float): Temperature to adjust attention. Default: 0.05
+    """
+
+    def __init__(self, *arg, temperature, **kwargs):
+        super().__init__(*arg, **kwargs)
+        self.temperature = temperature
+        self.conv_mask = nn.Conv2d(self.in_channels, 1, kernel_size=1)
+
+    def embedded_gaussian(self, theta_x, phi_x):
+        """Embedded gaussian with temperature."""
+
+        # NonLocal2d pairwise_weight: [N, HxW, HxW]
+        pairwise_weight = torch.matmul(theta_x, phi_x)
+        if self.use_scale:
+            # theta_x.shape[-1] is `self.inter_channels`
+            pairwise_weight /= theta_x.shape[-1]**0.5
+        pairwise_weight /= self.temperature
+        pairwise_weight = pairwise_weight.softmax(dim=-1)
+        return pairwise_weight
+
+    def forward(self, x):
+        # x: [N, C, H, W]
+        n = x.size(0)
+
+        # g_x: [N, HxW, C]
+        g_x = self.g(x).view(n, self.inter_channels, -1)
+        g_x = g_x.permute(0, 2, 1)
+
+        # theta_x: [N, HxW, C], phi_x: [N, C, HxW]
+        if self.mode == 'gaussian':
+            theta_x = x.view(n, self.in_channels, -1)
+            theta_x = theta_x.permute(0, 2, 1)
+            if self.sub_sample:
+                phi_x = self.phi(x).view(n, self.in_channels, -1)
+            else:
+                phi_x = x.view(n, self.in_channels, -1)
+        elif self.mode == 'concatenation':
+            theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
+            phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
+        else:
+            theta_x = self.theta(x).view(n, self.inter_channels, -1)
+            theta_x = theta_x.permute(0, 2, 1)
+            phi_x = self.phi(x).view(n, self.inter_channels, -1)
+
+        # subtract mean
+        theta_x -= theta_x.mean(dim=-2, keepdim=True)
+        phi_x -= phi_x.mean(dim=-1, keepdim=True)
+
+        pairwise_func = getattr(self, self.mode)
+        # pairwise_weight: [N, HxW, HxW]
+        pairwise_weight = pairwise_func(theta_x, phi_x)
+
+        # y: [N, HxW, C]
+        y = torch.matmul(pairwise_weight, g_x)
+        # y: [N, C, H, W]
+        y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
+                                                    *x.size()[2:])
+
+        # unary_mask: [N, 1, HxW]
+        unary_mask = self.conv_mask(x)
+        unary_mask = unary_mask.view(n, 1, -1)
+        unary_mask = unary_mask.softmax(dim=-1)
+        # unary_x: [N, 1, C]
+        unary_x = torch.matmul(unary_mask, g_x)
+        # unary_x: [N, C, 1, 1]
+        unary_x = unary_x.permute(0, 2, 1).contiguous().reshape(
+            n, self.inter_channels, 1, 1)
+
+        output = x + self.conv_out(y + unary_x)
+
+        return output
+
+
+@HEADS.register_module()
+class DNLHead(FCNHead):
+    """Disentangled Non-Local Neural Networks.
+
+    This head is the implementation of `DNLNet
+    <https://arxiv.org/abs/2006.06668>`_.
+
+    Args:
+        reduction (int): Reduction factor of projection transform. Default: 2.
+        use_scale (bool): Whether to scale pairwise_weight by
+            sqrt(1/inter_channels). Default: True.
+        mode (str): The nonlocal mode. Options are 'embedded_gaussian',
+            'dot_product'. Default: 'embedded_gaussian'.
+        temperature (float): Temperature to adjust attention. Default: 0.05
+    """
+
+    def __init__(self,
+                 reduction=2,
+                 use_scale=True,
+                 mode='embedded_gaussian',
+                 temperature=0.05,
+                 **kwargs):
+        super(DNLHead, self).__init__(num_convs=2, **kwargs)
+        self.reduction = reduction
+        self.use_scale = use_scale
+        self.mode = mode
+        self.temperature = temperature
+        self.dnl_block = DisentangledNonLocal2d(
+            in_channels=self.channels,
+            reduction=self.reduction,
+            use_scale=self.use_scale,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            mode=self.mode,
+            temperature=self.temperature)
+
+    def forward(self, inputs):
+        """Forward function."""
+        x = self._transform_inputs(inputs)
+        output = self.convs[0](x)
+        output = self.dnl_block(output)
+        output = self.convs[1](output)
+        if self.concat_input:
+            output = self.conv_cat(torch.cat([x, output], dim=1))
+        output = self.cls_seg(output)
+        return output
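
DNLHead is an FCNHead with the disentangled non-local block inserted between its two convs. A minimal sketch of that path, with assumed channel numbers and the same caveats about the vendored packages:

```python
import torch
from annotator.uniformer.mmseg.models.decode_heads.dnl_head import DNLHead

head = DNLHead(
    in_channels=2048, channels=512, num_classes=19,
    temperature=0.05, norm_cfg=dict(type='BN'))
# convs[0] -> dnl_block -> convs[1] -> concat with input -> cls_seg
out = head([torch.randn(2, 2048, 16, 16)])
print(out.shape)  # torch.Size([2, 19, 16, 16])
```
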
diff --git a/annotator/uniformer/mmseg/models/decode_heads/ema_head.py b/annotator/uniformer/mmseg/models/decode_heads/ema_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..12267cb40569d2b5a4a2955a6dc2671377ff5e0a
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/ema_head.py
@@ -0,0 +1,168 @@
+import math
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+def reduce_mean(tensor):
+    """Reduce mean when distributed training."""
+    if not (dist.is_available() and dist.is_initialized()):
+        return tensor
+    tensor = tensor.clone()
+    dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
+    return tensor
+
+
+class EMAModule(nn.Module):
+    """Expectation Maximization Attention Module used in EMANet.
+
+    Args:
+        channels (int): Channels of the whole module.
+        num_bases (int): Number of bases.
+        num_stages (int): Number of the EM iterations.
+        momentum (float): Momentum used to update the bases.
+    """
+
+    def __init__(self, channels, num_bases, num_stages, momentum):
+        super(EMAModule, self).__init__()
+        assert num_stages >= 1, 'num_stages must be at least 1!'
+        self.num_bases = num_bases
+        self.num_stages = num_stages
+        self.momentum = momentum
+
+        bases = torch.zeros(1, channels, self.num_bases)
+        bases.normal_(0, math.sqrt(2. / self.num_bases))
+        # [1, channels, num_bases]
+        bases = F.normalize(bases, dim=1, p=2)
+        self.register_buffer('bases', bases)
+
+    def forward(self, feats):
+        """Forward function."""
+        batch_size, channels, height, width = feats.size()
+        # [batch_size, channels, height*width]
+        feats = feats.view(batch_size, channels, height * width)
+        # [batch_size, channels, num_bases]
+        bases = self.bases.repeat(batch_size, 1, 1)
+
+        with torch.no_grad():
+            for i in range(self.num_stages):
+                # [batch_size, height*width, num_bases]
+                attention = torch.einsum('bcn,bck->bnk', feats, bases)
+                attention = F.softmax(attention, dim=2)
+                # l1 norm
+                attention_normed = F.normalize(attention, dim=1, p=1)
+                # [batch_size, channels, num_bases]
+                bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
+                # l2 norm
+                bases = F.normalize(bases, dim=1, p=2)
+
+        feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)
+        feats_recon = feats_recon.view(batch_size, channels, height, width)
+
+        if self.training:
+            bases = bases.mean(dim=0, keepdim=True)
+            bases = reduce_mean(bases)
+            # l2 norm
+            bases = F.normalize(bases, dim=1, p=2)
+            self.bases = (1 -
+                          self.momentum) * self.bases + self.momentum * bases
+
+        return feats_recon
+
+
+@HEADS.register_module()
+class EMAHead(BaseDecodeHead):
+    """Expectation Maximization Attention Networks for Semantic Segmentation.
+
+    This head is the implementation of `EMANet
+    <https://arxiv.org/abs/1907.13426>`_.
+
+    Args:
+        ema_channels (int): EMA module channels.
+        num_bases (int): Number of bases.
+        num_stages (int): Number of the EM iterations.
+        concat_input (bool): Whether to concatenate the input and output of
+            convs before the classification layer. Default: True.
+        momentum (float): Momentum to update the bases. Default: 0.1.
+    """
+
+    def __init__(self,
+                 ema_channels,
+                 num_bases,
+                 num_stages,
+                 concat_input=True,
+                 momentum=0.1,
+                 **kwargs):
+        super(EMAHead, self).__init__(**kwargs)
+        self.ema_channels = ema_channels
+        self.num_bases = num_bases
+        self.num_stages = num_stages
+        self.concat_input = concat_input
+        self.momentum = momentum
+        self.ema_module = EMAModule(self.ema_channels, self.num_bases,
+                                    self.num_stages, self.momentum)
+
+        self.ema_in_conv = ConvModule(
+            self.in_channels,
+            self.ema_channels,
+            3,
+            padding=1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        # project (0, inf) -> (-inf, inf)
+        self.ema_mid_conv = ConvModule(
+            self.ema_channels,
+            self.ema_channels,
+            1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=None,
+            act_cfg=None)
+        for param in self.ema_mid_conv.parameters():
+            param.requires_grad = False
+
+        self.ema_out_conv = ConvModule(
+            self.ema_channels,
+            self.ema_channels,
+            1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=None)
+        self.bottleneck = ConvModule(
+            self.ema_channels,
+            self.channels,
+            3,
+            padding=1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        if self.concat_input:
+            self.conv_cat = ConvModule(
+                self.in_channels + self.channels,
+                self.channels,
+                kernel_size=3,
+                padding=1,
+                conv_cfg=self.conv_cfg,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg)
+
+    def forward(self, inputs):
+        """Forward function."""
+        x = self._transform_inputs(inputs)
+        feats = self.ema_in_conv(x)
+        identity = feats
+        feats = self.ema_mid_conv(feats)
+        recon = self.ema_module(feats)
+        recon = F.relu(recon, inplace=True)
+        recon = self.ema_out_conv(recon)
+        output = F.relu(identity + recon, inplace=True)
+        output = self.bottleneck(output)
+        if self.concat_input:
+            output = self.conv_cat(torch.cat([x, output], dim=1))
+        output = self.cls_seg(output)
+        return output
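
A quick way to exercise the EMA iteration above is a single forward pass; eval mode skips the momentum update of the bases, and reduce_mean is a no-op without an initialized process group. The shapes and channels below are assumptions.

```python
import torch
from annotator.uniformer.mmseg.models.decode_heads.ema_head import EMAHead

head = EMAHead(
    ema_channels=256, num_bases=64, num_stages=3, momentum=0.1,
    in_channels=2048, channels=256, num_classes=19,
    norm_cfg=dict(type='BN'))
head.eval()  # no base update, BN uses running stats
with torch.no_grad():
    out = head([torch.randn(2, 2048, 32, 32)])
print(out.shape)  # torch.Size([2, 19, 32, 32])
```
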
diff --git a/annotator/uniformer/mmseg/models/decode_heads/enc_head.py b/annotator/uniformer/mmseg/models/decode_heads/enc_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..da57af617e05d41761628fd2d6d232655b32d905
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/enc_head.py
@@ -0,0 +1,187 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule, build_norm_layer
+
+from annotator.uniformer.mmseg.ops import Encoding, resize
+from ..builder import HEADS, build_loss
+from .decode_head import BaseDecodeHead
+
+
+class EncModule(nn.Module):
+    """Encoding Module used in EncNet.
+
+    Args:
+        in_channels (int): Input channels.
+        num_codes (int): Number of code words.
+        conv_cfg (dict|None): Config of conv layers.
+        norm_cfg (dict|None): Config of norm layers.
+        act_cfg (dict): Config of activation layers.
+    """
+
+    def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg):
+        super(EncModule, self).__init__()
+        self.encoding_project = ConvModule(
+            in_channels,
+            in_channels,
+            1,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+        # TODO: resolve this hack
+        # change to 1d
+        if norm_cfg is not None:
+            encoding_norm_cfg = norm_cfg.copy()
+            if encoding_norm_cfg['type'] in ['BN', 'IN']:
+                encoding_norm_cfg['type'] += '1d'
+            else:
+                encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace(
+                    '2d', '1d')
+        else:
+            # fallback to BN1d
+            encoding_norm_cfg = dict(type='BN1d')
+        self.encoding = nn.Sequential(
+            Encoding(channels=in_channels, num_codes=num_codes),
+            build_norm_layer(encoding_norm_cfg, num_codes)[1],
+            nn.ReLU(inplace=True))
+        self.fc = nn.Sequential(
+            nn.Linear(in_channels, in_channels), nn.Sigmoid())
+
+    def forward(self, x):
+        """Forward function."""
+        encoding_projection = self.encoding_project(x)
+        encoding_feat = self.encoding(encoding_projection).mean(dim=1)
+        batch_size, channels, _, _ = x.size()
+        gamma = self.fc(encoding_feat)
+        y = gamma.view(batch_size, channels, 1, 1)
+        output = F.relu_(x + x * y)
+        return encoding_feat, output
+
+
+@HEADS.register_module()
+class EncHead(BaseDecodeHead):
+    """Context Encoding for Semantic Segmentation.
+
+    This head is the implementation of `EncNet
+    <https://arxiv.org/abs/1803.08904>`_.
+
+    Args:
+        num_codes (int): Number of code words. Default: 32.
+        use_se_loss (bool): Whether to use the Semantic Encoding Loss
+            (SE-loss) to regularize training. Default: True.
+        add_lateral (bool): Whether to use lateral connections to fuse features.
+            Default: False.
+        loss_se_decode (dict): Config of decode loss.
+            Default: dict(type='CrossEntropyLoss', use_sigmoid=True).
+    """
+
+    def __init__(self,
+                 num_codes=32,
+                 use_se_loss=True,
+                 add_lateral=False,
+                 loss_se_decode=dict(
+                     type='CrossEntropyLoss',
+                     use_sigmoid=True,
+                     loss_weight=0.2),
+                 **kwargs):
+        super(EncHead, self).__init__(
+            input_transform='multiple_select', **kwargs)
+        self.use_se_loss = use_se_loss
+        self.add_lateral = add_lateral
+        self.num_codes = num_codes
+        self.bottleneck = ConvModule(
+            self.in_channels[-1],
+            self.channels,
+            3,
+            padding=1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        if add_lateral:
+            self.lateral_convs = nn.ModuleList()
+            for in_channels in self.in_channels[:-1]:  # skip the last one
+                self.lateral_convs.append(
+                    ConvModule(
+                        in_channels,
+                        self.channels,
+                        1,
+                        conv_cfg=self.conv_cfg,
+                        norm_cfg=self.norm_cfg,
+                        act_cfg=self.act_cfg))
+            self.fusion = ConvModule(
+                len(self.in_channels) * self.channels,
+                self.channels,
+                3,
+                padding=1,
+                conv_cfg=self.conv_cfg,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg)
+        self.enc_module = EncModule(
+            self.channels,
+            num_codes=num_codes,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        if self.use_se_loss:
+            self.loss_se_decode = build_loss(loss_se_decode)
+            self.se_layer = nn.Linear(self.channels, self.num_classes)
+
+    def forward(self, inputs):
+        """Forward function."""
+        inputs = self._transform_inputs(inputs)
+        feat = self.bottleneck(inputs[-1])
+        if self.add_lateral:
+            laterals = [
+                resize(
+                    lateral_conv(inputs[i]),
+                    size=feat.shape[2:],
+                    mode='bilinear',
+                    align_corners=self.align_corners)
+                for i, lateral_conv in enumerate(self.lateral_convs)
+            ]
+            feat = self.fusion(torch.cat([feat, *laterals], 1))
+        encode_feat, output = self.enc_module(feat)
+        output = self.cls_seg(output)
+        if self.use_se_loss:
+            se_output = self.se_layer(encode_feat)
+            return output, se_output
+        else:
+            return output
+
+    def forward_test(self, inputs, img_metas, test_cfg):
+        """Forward function for testing, ignore se_loss."""
+        if self.use_se_loss:
+            return self.forward(inputs)[0]
+        else:
+            return self.forward(inputs)
+
+    @staticmethod
+    def _convert_to_onehot_labels(seg_label, num_classes):
+        """Convert segmentation label to onehot.
+
+        Args:
+            seg_label (Tensor): Segmentation label of shape (N, H, W).
+            num_classes (int): Number of classes.
+
+        Returns:
+            Tensor: Onehot labels of shape (N, num_classes).
+        """
+
+        batch_size = seg_label.size(0)
+        onehot_labels = seg_label.new_zeros((batch_size, num_classes))
+        for i in range(batch_size):
+            hist = seg_label[i].float().histc(
+                bins=num_classes, min=0, max=num_classes - 1)
+            onehot_labels[i] = hist > 0
+        return onehot_labels
+
+    def losses(self, seg_logit, seg_label):
+        """Compute segmentation and semantic encoding loss."""
+        seg_logit, se_seg_logit = seg_logit
+        loss = dict()
+        loss.update(super(EncHead, self).losses(seg_logit, seg_label))
+        se_loss = self.loss_se_decode(
+            se_seg_logit,
+            self._convert_to_onehot_labels(seg_label, self.num_classes))
+        loss['loss_se'] = se_loss
+        return loss
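
EncHead expects multiple backbone levels ('multiple_select'); with use_se_loss=True its forward returns a (seg_logits, se_logits) tuple. Below is a sketch with the SE branch disabled for brevity; the channel counts and sizes are assumptions.

```python
import torch
from annotator.uniformer.mmseg.models.decode_heads.enc_head import EncHead

head = EncHead(
    in_channels=[512, 1024, 2048], in_index=[0, 1, 2],
    channels=512, num_codes=32,
    use_se_loss=False, add_lateral=False,
    num_classes=19, norm_cfg=dict(type='BN'))
feats = [torch.randn(2, 512, 32, 32),
         torch.randn(2, 1024, 16, 16),
         torch.randn(2, 2048, 16, 16)]
out = head(feats)  # with use_se_loss=True this would be (seg_logits, se_logits)
print(out.shape)   # torch.Size([2, 19, 16, 16])
```
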
diff --git a/annotator/uniformer/mmseg/models/decode_heads/fcn_head.py b/annotator/uniformer/mmseg/models/decode_heads/fcn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..edb32c283fa4baada6b4a0bf3f7540c3580c3468
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/fcn_head.py
@@ -0,0 +1,81 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+@HEADS.register_module()
+class FCNHead(BaseDecodeHead):
+    """Fully Convolution Networks for Semantic Segmentation.
+
+    This head is the implementation of `FCNNet <https://arxiv.org/abs/1411.4038>`_.
+
+    Args:
+        num_convs (int): Number of convs in the head. Default: 2.
+        kernel_size (int): The kernel size for convs in the head. Default: 3.
+        concat_input (bool): Whether to concatenate the input and output of
+            convs before the classification layer. Default: True.
+        dilation (int): The dilation rate for convs in the head. Default: 1.
+    """
+
+    def __init__(self,
+                 num_convs=2,
+                 kernel_size=3,
+                 concat_input=True,
+                 dilation=1,
+                 **kwargs):
+        assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int)
+        self.num_convs = num_convs
+        self.concat_input = concat_input
+        self.kernel_size = kernel_size
+        super(FCNHead, self).__init__(**kwargs)
+        if num_convs == 0:
+            assert self.in_channels == self.channels
+
+        conv_padding = (kernel_size // 2) * dilation
+        convs = []
+        convs.append(
+            ConvModule(
+                self.in_channels,
+                self.channels,
+                kernel_size=kernel_size,
+                padding=conv_padding,
+                dilation=dilation,
+                conv_cfg=self.conv_cfg,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg))
+        for i in range(num_convs - 1):
+            convs.append(
+                ConvModule(
+                    self.channels,
+                    self.channels,
+                    kernel_size=kernel_size,
+                    padding=conv_padding,
+                    dilation=dilation,
+                    conv_cfg=self.conv_cfg,
+                    norm_cfg=self.norm_cfg,
+                    act_cfg=self.act_cfg))
+        if num_convs == 0:
+            self.convs = nn.Identity()
+        else:
+            self.convs = nn.Sequential(*convs)
+        if self.concat_input:
+            self.conv_cat = ConvModule(
+                self.in_channels + self.channels,
+                self.channels,
+                kernel_size=kernel_size,
+                padding=kernel_size // 2,
+                conv_cfg=self.conv_cfg,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg)
+
+    def forward(self, inputs):
+        """Forward function."""
+        x = self._transform_inputs(inputs)
+        output = self.convs(x)
+        if self.concat_input:
+            output = self.conv_cat(torch.cat([x, output], dim=1))
+        output = self.cls_seg(output)
+        return output
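
Since most heads in this directory subclass FCNHead, a minimal forward check is worth having; the numbers are placeholders, and the usual caveats about the vendored packages apply.

```python
import torch
from annotator.uniformer.mmseg.models.decode_heads.fcn_head import FCNHead

head = FCNHead(
    in_channels=2048, channels=512, num_convs=2, concat_input=True,
    num_classes=19, norm_cfg=dict(type='BN'))
out = head([torch.randn(2, 2048, 16, 16)])
print(out.shape)  # torch.Size([2, 19, 16, 16])
```
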
diff --git a/annotator/uniformer/mmseg/models/decode_heads/fpn_head.py b/annotator/uniformer/mmseg/models/decode_heads/fpn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..1241c55b0813d1ecdddf1e66e7c5031fbf78ed50
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/fpn_head.py
@@ -0,0 +1,68 @@
+import numpy as np
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+@HEADS.register_module()
+class FPNHead(BaseDecodeHead):
+    """Panoptic Feature Pyramid Networks.
+
+    This head is the implementation of `Semantic FPN
+    <https://arxiv.org/abs/1901.02446>`_.
+
+    Args:
+        feature_strides (tuple[int]): The strides of the input feature maps,
+            used to build the stacked lateral heads. All strides are assumed to
+            be powers of 2, and the first one has the largest resolution.
+    """
+
+    def __init__(self, feature_strides, **kwargs):
+        super(FPNHead, self).__init__(
+            input_transform='multiple_select', **kwargs)
+        assert len(feature_strides) == len(self.in_channels)
+        assert min(feature_strides) == feature_strides[0]
+        self.feature_strides = feature_strides
+
+        self.scale_heads = nn.ModuleList()
+        for i in range(len(feature_strides)):
+            head_length = max(
+                1,
+                int(np.log2(feature_strides[i]) - np.log2(feature_strides[0])))
+            scale_head = []
+            for k in range(head_length):
+                scale_head.append(
+                    ConvModule(
+                        self.in_channels[i] if k == 0 else self.channels,
+                        self.channels,
+                        3,
+                        padding=1,
+                        conv_cfg=self.conv_cfg,
+                        norm_cfg=self.norm_cfg,
+                        act_cfg=self.act_cfg))
+                if feature_strides[i] != feature_strides[0]:
+                    scale_head.append(
+                        nn.Upsample(
+                            scale_factor=2,
+                            mode='bilinear',
+                            align_corners=self.align_corners))
+            self.scale_heads.append(nn.Sequential(*scale_head))
+
+    def forward(self, inputs):
+        """Forward function."""
+        x = self._transform_inputs(inputs)
+
+        output = self.scale_heads[0](x[0])
+        for i in range(1, len(self.feature_strides)):
+            # non inplace
+            output = output + resize(
+                self.scale_heads[i](x[i]),
+                size=output.shape[2:],
+                mode='bilinear',
+                align_corners=self.align_corners)
+
+        output = self.cls_seg(output)
+        return output
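
Each scale head upsamples its level by factors of 2 until it matches the finest stride, so the feature sizes must be consistent with feature_strides. A sketch with assumed FPN outputs (256 channels at strides 4/8/16/32 of a 256x256 input):

```python
import torch
from annotator.uniformer.mmseg.models.decode_heads.fpn_head import FPNHead

head = FPNHead(
    feature_strides=[4, 8, 16, 32],
    in_channels=[256, 256, 256, 256], in_index=[0, 1, 2, 3],
    channels=128, num_classes=19, norm_cfg=dict(type='BN'))
feats = [torch.randn(2, 256, s, s) for s in (64, 32, 16, 8)]
out = head(feats)
print(out.shape)  # torch.Size([2, 19, 64, 64])
```
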
diff --git a/annotator/uniformer/mmseg/models/decode_heads/gc_head.py b/annotator/uniformer/mmseg/models/decode_heads/gc_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..70741245af975800840709911bd18d72247e3e04
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/gc_head.py
@@ -0,0 +1,47 @@
+import torch
+from annotator.uniformer.mmcv.cnn import ContextBlock
+
+from ..builder import HEADS
+from .fcn_head import FCNHead
+
+
+@HEADS.register_module()
+class GCHead(FCNHead):
+    """GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond.
+
+    This head is the implementation of `GCNet
+    <https://arxiv.org/abs/1904.11492>`_.
+
+    Args:
+        ratio (float): Multiplier of channels ratio. Default: 1/4.
+        pooling_type (str): The pooling type of context aggregation.
+            Options are 'att', 'avg'. Default: 'avg'.
+        fusion_types (tuple[str]): The fusion type for feature fusion.
+            Options are 'channel_add', 'channel_mul'. Default: ('channel_add',)
+    """
+
+    def __init__(self,
+                 ratio=1 / 4.,
+                 pooling_type='att',
+                 fusion_types=('channel_add', ),
+                 **kwargs):
+        super(GCHead, self).__init__(num_convs=2, **kwargs)
+        self.ratio = ratio
+        self.pooling_type = pooling_type
+        self.fusion_types = fusion_types
+        self.gc_block = ContextBlock(
+            in_channels=self.channels,
+            ratio=self.ratio,
+            pooling_type=self.pooling_type,
+            fusion_types=self.fusion_types)
+
+    def forward(self, inputs):
+        """Forward function."""
+        x = self._transform_inputs(inputs)
+        output = self.convs[0](x)
+        output = self.gc_block(output)
+        output = self.convs[1](output)
+        if self.concat_input:
+            output = self.conv_cat(torch.cat([x, output], dim=1))
+        output = self.cls_seg(output)
+        return output
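
GCHead only swaps the non-local block for mmcv's ContextBlock, so the same smoke-test pattern applies (placeholder channels):

```python
import torch
from annotator.uniformer.mmseg.models.decode_heads.gc_head import GCHead

head = GCHead(in_channels=2048, channels=512, num_classes=19,
              norm_cfg=dict(type='BN'))
out = head([torch.randn(2, 2048, 16, 16)])  # convs[0] -> gc_block -> convs[1] -> cls_seg
print(out.shape)  # torch.Size([2, 19, 16, 16])
```
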
diff --git a/annotator/uniformer/mmseg/models/decode_heads/lraspp_head.py b/annotator/uniformer/mmseg/models/decode_heads/lraspp_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..69bf320934d787aaa11984a0c4effe9ad8015b22
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/lraspp_head.py
@@ -0,0 +1,90 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv import is_tuple_of
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+@HEADS.register_module()
+class LRASPPHead(BaseDecodeHead):
+    """Lite R-ASPP (LRASPP) head is proposed in Searching for MobileNetV3.
+
+    This head is the improved implementation of `Searching for MobileNetV3
+    <https://ieeexplore.ieee.org/document/9008835>`_.
+
+    Args:
+        branch_channels (tuple[int]): The number of output channels of
+            each branch. Default: (32, 64).
+    """
+
+    def __init__(self, branch_channels=(32, 64), **kwargs):
+        super(LRASPPHead, self).__init__(**kwargs)
+        if self.input_transform != 'multiple_select':
+            raise ValueError('in Lite R-ASPP (LRASPP) head, input_transform '
+                             f'must be \'multiple_select\'. But received '
+                             f'\'{self.input_transform}\'')
+        assert is_tuple_of(branch_channels, int)
+        assert len(branch_channels) == len(self.in_channels) - 1
+        self.branch_channels = branch_channels
+
+        self.convs = nn.Sequential()
+        self.conv_ups = nn.Sequential()
+        for i in range(len(branch_channels)):
+            self.convs.add_module(
+                f'conv{i}',
+                nn.Conv2d(
+                    self.in_channels[i], branch_channels[i], 1, bias=False))
+            self.conv_ups.add_module(
+                f'conv_up{i}',
+                ConvModule(
+                    self.channels + branch_channels[i],
+                    self.channels,
+                    1,
+                    norm_cfg=self.norm_cfg,
+                    act_cfg=self.act_cfg,
+                    bias=False))
+
+        self.conv_up_input = nn.Conv2d(self.channels, self.channels, 1)
+
+        self.aspp_conv = ConvModule(
+            self.in_channels[-1],
+            self.channels,
+            1,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg,
+            bias=False)
+        self.image_pool = nn.Sequential(
+            nn.AvgPool2d(kernel_size=49, stride=(16, 20)),
+            ConvModule(
+                self.in_channels[2],
+                self.channels,
+                1,
+                act_cfg=dict(type='Sigmoid'),
+                bias=False))
+
+    def forward(self, inputs):
+        """Forward function."""
+        inputs = self._transform_inputs(inputs)
+
+        x = inputs[-1]
+
+        x = self.aspp_conv(x) * resize(
+            self.image_pool(x),
+            size=x.size()[2:],
+            mode='bilinear',
+            align_corners=self.align_corners)
+        x = self.conv_up_input(x)
+
+        for i in range(len(self.branch_channels) - 1, -1, -1):
+            x = resize(
+                x,
+                size=inputs[i].size()[2:],
+                mode='bilinear',
+                align_corners=self.align_corners)
+            x = torch.cat([x, self.convs[i](inputs[i])], 1)
+            x = self.conv_ups[i](x)
+
+        return self.cls_seg(x)
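
Note the fixed 49x49 average pool: the deepest input feature must be at least 49 pixels in both dimensions, which matches MobileNetV3-scale inputs. The shapes and channel counts below are assumptions chosen to satisfy that constraint.

```python
import torch
from annotator.uniformer.mmseg.models.decode_heads.lraspp_head import LRASPPHead

head = LRASPPHead(
    branch_channels=(32, 64),
    in_channels=(16, 24, 576), in_index=(0, 1, 2),
    input_transform='multiple_select',
    channels=128, num_classes=19,
    norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'))
feats = [torch.randn(1, 16, 256, 512),   # early, high-resolution stage
         torch.randn(1, 24, 128, 256),
         torch.randn(1, 576, 64, 128)]   # deepest stage; must be >= 49 in H and W
out = head(feats)
print(out.shape)  # torch.Size([1, 19, 256, 512])
```
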
diff --git a/annotator/uniformer/mmseg/models/decode_heads/nl_head.py b/annotator/uniformer/mmseg/models/decode_heads/nl_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..3eee424199e6aa363b564e2a3340a070db04db86
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/nl_head.py
@@ -0,0 +1,49 @@
+import torch
+from annotator.uniformer.mmcv.cnn import NonLocal2d
+
+from ..builder import HEADS
+from .fcn_head import FCNHead
+
+
+@HEADS.register_module()
+class NLHead(FCNHead):
+    """Non-local Neural Networks.
+
+    This head is the implementation of `NLNet
+    <https://arxiv.org/abs/1711.07971>`_.
+
+    Args:
+        reduction (int): Reduction factor of projection transform. Default: 2.
+        use_scale (bool): Whether to scale pairwise_weight by
+            sqrt(1/inter_channels). Default: True.
+        mode (str): The nonlocal mode. Options are 'embedded_gaussian',
+            'dot_product'. Default: 'embedded_gaussian'.
+    """
+
+    def __init__(self,
+                 reduction=2,
+                 use_scale=True,
+                 mode='embedded_gaussian',
+                 **kwargs):
+        super(NLHead, self).__init__(num_convs=2, **kwargs)
+        self.reduction = reduction
+        self.use_scale = use_scale
+        self.mode = mode
+        self.nl_block = NonLocal2d(
+            in_channels=self.channels,
+            reduction=self.reduction,
+            use_scale=self.use_scale,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            mode=self.mode)
+
+    def forward(self, inputs):
+        """Forward function."""
+        x = self._transform_inputs(inputs)
+        output = self.convs[0](x)
+        output = self.nl_block(output)
+        output = self.convs[1](output)
+        if self.concat_input:
+            output = self.conv_cat(torch.cat([x, output], dim=1))
+        output = self.cls_seg(output)
+        return output
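
NLHead differs from DNLHead only in using the plain NonLocal2d block; a matching sketch with placeholder sizes:

```python
import torch
from annotator.uniformer.mmseg.models.decode_heads.nl_head import NLHead

head = NLHead(in_channels=2048, channels=512, num_classes=19,
              norm_cfg=dict(type='BN'))
out = head([torch.randn(2, 2048, 16, 16)])
print(out.shape)  # torch.Size([2, 19, 16, 16])
```
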
diff --git a/annotator/uniformer/mmseg/models/decode_heads/ocr_head.py b/annotator/uniformer/mmseg/models/decode_heads/ocr_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..715852e94e81dc46623972748285d2d19237a341
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/ocr_head.py
@@ -0,0 +1,127 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from ..utils import SelfAttentionBlock as _SelfAttentionBlock
+from .cascade_decode_head import BaseCascadeDecodeHead
+
+
+class SpatialGatherModule(nn.Module):
+    """Aggregate the context features according to the initial predicted
+    probability distribution.
+
+    Employ the soft-weighted method to aggregate the context.
+    """
+
+    def __init__(self, scale):
+        super(SpatialGatherModule, self).__init__()
+        self.scale = scale
+
+    def forward(self, feats, probs):
+        """Forward function."""
+        batch_size, num_classes, height, width = probs.size()
+        channels = feats.size(1)
+        probs = probs.view(batch_size, num_classes, -1)
+        feats = feats.view(batch_size, channels, -1)
+        # [batch_size, height*width, channels]
+        feats = feats.permute(0, 2, 1)
+        # [batch_size, num_classes, height*width]
+        probs = F.softmax(self.scale * probs, dim=2)
+        # [batch_size, num_classes, channels]
+        ocr_context = torch.matmul(probs, feats)
+        ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3)
+        return ocr_context
+
+
+class ObjectAttentionBlock(_SelfAttentionBlock):
+    """Make a OCR used SelfAttentionBlock."""
+
+    def __init__(self, in_channels, channels, scale, conv_cfg, norm_cfg,
+                 act_cfg):
+        if scale > 1:
+            query_downsample = nn.MaxPool2d(kernel_size=scale)
+        else:
+            query_downsample = None
+        super(ObjectAttentionBlock, self).__init__(
+            key_in_channels=in_channels,
+            query_in_channels=in_channels,
+            channels=channels,
+            out_channels=in_channels,
+            share_key_query=False,
+            query_downsample=query_downsample,
+            key_downsample=None,
+            key_query_num_convs=2,
+            key_query_norm=True,
+            value_out_num_convs=1,
+            value_out_norm=True,
+            matmul_norm=True,
+            with_out=True,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+        self.bottleneck = ConvModule(
+            in_channels * 2,
+            in_channels,
+            1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+
+    def forward(self, query_feats, key_feats):
+        """Forward function."""
+        context = super(ObjectAttentionBlock,
+                        self).forward(query_feats, key_feats)
+        output = self.bottleneck(torch.cat([context, query_feats], dim=1))
+        if self.query_downsample is not None:
+            output = resize(query_feats)
+
+        return output
+
+
+@HEADS.register_module()
+class OCRHead(BaseCascadeDecodeHead):
+    """Object-Contextual Representations for Semantic Segmentation.
+
+    This head is the implementation of `OCRNet
+    <https://arxiv.org/abs/1909.11065>`_.
+
+    Args:
+        ocr_channels (int): The intermediate channels of OCR block.
+        scale (int): The scale of the probability map used in
+            SpatialGatherModule. Default: 1.
+    """
+
+    def __init__(self, ocr_channels, scale=1, **kwargs):
+        super(OCRHead, self).__init__(**kwargs)
+        self.ocr_channels = ocr_channels
+        self.scale = scale
+        self.object_context_block = ObjectAttentionBlock(
+            self.channels,
+            self.ocr_channels,
+            self.scale,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        self.spatial_gather_module = SpatialGatherModule(self.scale)
+
+        self.bottleneck = ConvModule(
+            self.in_channels,
+            self.channels,
+            3,
+            padding=1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+
+    def forward(self, inputs, prev_output):
+        """Forward function."""
+        x = self._transform_inputs(inputs)
+        feats = self.bottleneck(x)
+        context = self.spatial_gather_module(feats, prev_output)
+        object_context = self.object_context_block(feats, context)
+        output = self.cls_seg(object_context)
+
+        return output
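
OCRHead is a cascade head, so its forward takes both the backbone features and the coarse logits produced by a preceding head (typically an FCNHead). A sketch with assumed shapes:

```python
import torch
from annotator.uniformer.mmseg.models.decode_heads.ocr_head import OCRHead

head = OCRHead(ocr_channels=256, scale=1,
               in_channels=2048, channels=512, num_classes=19,
               norm_cfg=dict(type='BN'))
feats = [torch.randn(2, 2048, 16, 16)]
coarse = torch.randn(2, 19, 16, 16)   # logits from a preceding FCN-style head
out = head(feats, coarse)
print(out.shape)  # torch.Size([2, 19, 16, 16])
```
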
diff --git a/annotator/uniformer/mmseg/models/decode_heads/point_head.py b/annotator/uniformer/mmseg/models/decode_heads/point_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..3342aa28bb8d264b2c3d01cbf5098d145943c193
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/point_head.py
@@ -0,0 +1,349 @@
+# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py  # noqa
+
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule, normal_init
+from annotator.uniformer.mmcv.ops import point_sample
+
+from annotator.uniformer.mmseg.models.builder import HEADS
+from annotator.uniformer.mmseg.ops import resize
+from ..losses import accuracy
+from .cascade_decode_head import BaseCascadeDecodeHead
+
+
+def calculate_uncertainty(seg_logits):
+    """Estimate uncertainty based on seg logits.
+
+    For each location of the prediction ``seg_logits`` we estimate
+    uncertainty as the difference between top first and top second
+    predicted logits.
+
+    Args:
+        seg_logits (Tensor): Semantic segmentation logits,
+            shape (batch_size, num_classes, height, width).
+
+    Returns:
+        scores (Tensor): Uncertainty scores, where the most uncertain
+            locations have the highest score, shape
+            (batch_size, 1, height, width).
+    """
+    top2_scores = torch.topk(seg_logits, k=2, dim=1)[0]
+    return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)
+
+
+@HEADS.register_module()
+class PointHead(BaseCascadeDecodeHead):
+    """A mask point head use in PointRend.
+
+    ``PointHead`` uses a shared multi-layer perceptron (equivalent to
+    nn.Conv1d) to predict the logits of the input points. The fine-grained
+    and coarse features are concatenated together for prediction.
+
+    Args:
+        num_fcs (int): Number of fc layers in the head. Default: 3.
+        in_channels (int): Number of input channels. Default: 256.
+        fc_channels (int): Number of fc channels. Default: 256.
+        num_classes (int): Number of classes for logits. Default: 80.
+        class_agnostic (bool): Whether to use class-agnostic classification.
+            If so, the output channels of logits will be 1. Default: False.
+        coarse_pred_each_layer (bool): Whether to concatenate the coarse
+            feature with the output of each fc layer. Default: True.
+        conv_cfg (dict|None): Dictionary to construct and config conv layer.
+            Default: dict(type='Conv1d'))
+        norm_cfg (dict|None): Dictionary to construct and config norm layer.
+            Default: None.
+        loss_point (dict): Dictionary to construct and config loss layer of
+            point head. Default: dict(type='CrossEntropyLoss', use_mask=True,
+            loss_weight=1.0).
+    """
+
+    def __init__(self,
+                 num_fcs=3,
+                 coarse_pred_each_layer=True,
+                 conv_cfg=dict(type='Conv1d'),
+                 norm_cfg=None,
+                 act_cfg=dict(type='ReLU', inplace=False),
+                 **kwargs):
+        super(PointHead, self).__init__(
+            input_transform='multiple_select',
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg,
+            **kwargs)
+
+        self.num_fcs = num_fcs
+        self.coarse_pred_each_layer = coarse_pred_each_layer
+
+        fc_in_channels = sum(self.in_channels) + self.num_classes
+        fc_channels = self.channels
+        self.fcs = nn.ModuleList()
+        for k in range(num_fcs):
+            fc = ConvModule(
+                fc_in_channels,
+                fc_channels,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg)
+            self.fcs.append(fc)
+            fc_in_channels = fc_channels
+            fc_in_channels += self.num_classes if self.coarse_pred_each_layer \
+                else 0
+        self.fc_seg = nn.Conv1d(
+            fc_in_channels,
+            self.num_classes,
+            kernel_size=1,
+            stride=1,
+            padding=0)
+        if self.dropout_ratio > 0:
+            self.dropout = nn.Dropout(self.dropout_ratio)
+        delattr(self, 'conv_seg')
+
+    def init_weights(self):
+        """Initialize weights of classification layer."""
+        normal_init(self.fc_seg, std=0.001)
+
+    def cls_seg(self, feat):
+        """Classify each pixel with fc."""
+        if self.dropout is not None:
+            feat = self.dropout(feat)
+        output = self.fc_seg(feat)
+        return output
+
+    def forward(self, fine_grained_point_feats, coarse_point_feats):
+        x = torch.cat([fine_grained_point_feats, coarse_point_feats], dim=1)
+        for fc in self.fcs:
+            x = fc(x)
+            if self.coarse_pred_each_layer:
+                x = torch.cat((x, coarse_point_feats), dim=1)
+        return self.cls_seg(x)
+
+    def _get_fine_grained_point_feats(self, x, points):
+        """Sample from fine grained features.
+
+        Args:
+            x (list[Tensor]): Feature pyramid from the neck or backbone.
+            points (Tensor): Point coordinates, shape (batch_size,
+                num_points, 2).
+
+        Returns:
+            fine_grained_feats (Tensor): Sampled fine grained feature,
+                shape (batch_size, sum(channels of x), num_points).
+        """
+
+        fine_grained_feats_list = [
+            point_sample(_, points, align_corners=self.align_corners)
+            for _ in x
+        ]
+        if len(fine_grained_feats_list) > 1:
+            fine_grained_feats = torch.cat(fine_grained_feats_list, dim=1)
+        else:
+            fine_grained_feats = fine_grained_feats_list[0]
+
+        return fine_grained_feats
+
+    def _get_coarse_point_feats(self, prev_output, points):
+        """Sample from fine grained features.
+
+        Args:
+            prev_output (list[Tensor]): Prediction of previous decode head.
+            points (Tensor): Point coordinates, shape (batch_size,
+                num_points, 2).
+
+        Returns:
+            coarse_feats (Tensor): Sampled coarse feature, shape (batch_size,
+                num_classes, num_points).
+        """
+
+        coarse_feats = point_sample(
+            prev_output, points, align_corners=self.align_corners)
+
+        return coarse_feats
+
+    def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg,
+                      train_cfg):
+        """Forward function for training.
+        Args:
+            inputs (list[Tensor]): List of multi-level img features.
+            prev_output (Tensor): The output of previous decode head.
+            img_metas (list[dict]): List of image info dict where each dict
+                has: 'img_shape', 'scale_factor', 'flip', and may also contain
+                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                For details on the values of these keys see
+                `mmseg/datasets/pipelines/formatting.py:Collect`.
+            gt_semantic_seg (Tensor): Semantic segmentation masks
+                used if the architecture supports semantic segmentation task.
+            train_cfg (dict): The training config.
+
+        Returns:
+            dict[str, Tensor]: a dictionary of loss components
+        """
+        x = self._transform_inputs(inputs)
+        with torch.no_grad():
+            points = self.get_points_train(
+                prev_output, calculate_uncertainty, cfg=train_cfg)
+        fine_grained_point_feats = self._get_fine_grained_point_feats(
+            x, points)
+        coarse_point_feats = self._get_coarse_point_feats(prev_output, points)
+        point_logits = self.forward(fine_grained_point_feats,
+                                    coarse_point_feats)
+        point_label = point_sample(
+            gt_semantic_seg.float(),
+            points,
+            mode='nearest',
+            align_corners=self.align_corners)
+        point_label = point_label.squeeze(1).long()
+
+        losses = self.losses(point_logits, point_label)
+
+        return losses
+
+    def forward_test(self, inputs, prev_output, img_metas, test_cfg):
+        """Forward function for testing.
+
+        Args:
+            inputs (list[Tensor]): List of multi-level img features.
+            prev_output (Tensor): The output of previous decode head.
+            img_metas (list[dict]): List of image info dict where each dict
+                has: 'img_shape', 'scale_factor', 'flip', and may also contain
+                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                For details on the values of these keys see
+                `mmseg/datasets/pipelines/formatting.py:Collect`.
+            test_cfg (dict): The testing config.
+
+        Returns:
+            Tensor: Output segmentation map.
+        """
+
+        x = self._transform_inputs(inputs)
+        refined_seg_logits = prev_output.clone()
+        for _ in range(test_cfg.subdivision_steps):
+            refined_seg_logits = resize(
+                refined_seg_logits,
+                scale_factor=test_cfg.scale_factor,
+                mode='bilinear',
+                align_corners=self.align_corners)
+            batch_size, channels, height, width = refined_seg_logits.shape
+            point_indices, points = self.get_points_test(
+                refined_seg_logits, calculate_uncertainty, cfg=test_cfg)
+            fine_grained_point_feats = self._get_fine_grained_point_feats(
+                x, points)
+            coarse_point_feats = self._get_coarse_point_feats(
+                prev_output, points)
+            point_logits = self.forward(fine_grained_point_feats,
+                                        coarse_point_feats)
+
+            point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)
+            refined_seg_logits = refined_seg_logits.reshape(
+                batch_size, channels, height * width)
+            refined_seg_logits = refined_seg_logits.scatter_(
+                2, point_indices, point_logits)
+            refined_seg_logits = refined_seg_logits.view(
+                batch_size, channels, height, width)
+
+        return refined_seg_logits
+
+    def losses(self, point_logits, point_label):
+        """Compute segmentation loss."""
+        loss = dict()
+        loss['loss_point'] = self.loss_decode(
+            point_logits, point_label, ignore_index=self.ignore_index)
+        loss['acc_point'] = accuracy(point_logits, point_label)
+        return loss
+
+    def get_points_train(self, seg_logits, uncertainty_func, cfg):
+        """Sample points for training.
+
+        Sample points in [0, 1] x [0, 1] coordinate space based on their
+        uncertainty. The uncertainties are calculated for each point using
+        'uncertainty_func' function that takes point's logit prediction as
+        input.
+
+        Args:
+            seg_logits (Tensor): Semantic segmentation logits, shape (
+                batch_size, num_classes, height, width).
+            uncertainty_func (func): uncertainty calculation function.
+            cfg (dict): Training config of point head.
+
+        Returns:
+            point_coords (Tensor): A tensor of shape (batch_size, num_points,
+                2) that contains the coordinates of ``num_points`` sampled
+                points.
+        """
+        num_points = cfg.num_points
+        oversample_ratio = cfg.oversample_ratio
+        importance_sample_ratio = cfg.importance_sample_ratio
+        assert oversample_ratio >= 1
+        assert 0 <= importance_sample_ratio <= 1
+        batch_size = seg_logits.shape[0]
+        num_sampled = int(num_points * oversample_ratio)
+        point_coords = torch.rand(
+            batch_size, num_sampled, 2, device=seg_logits.device)
+        point_logits = point_sample(seg_logits, point_coords)
+        # It is crucial to calculate uncertainty based on the sampled
+        # prediction value for the points. Calculating uncertainties of the
+        # coarse predictions first and sampling them for points leads to
+        # incorrect results.  To illustrate this: assume uncertainty func(
+        # logits)=-abs(logits), a sampled point between two coarse
+        # predictions with -1 and 1 logits has 0 logits, and therefore 0
+        # uncertainty value. However, if we calculate uncertainties for the
+        # coarse predictions first, both will have -1 uncertainty,
+        # and sampled point will get -1 uncertainty.
+        point_uncertainties = uncertainty_func(point_logits)
+        num_uncertain_points = int(importance_sample_ratio * num_points)
+        num_random_points = num_points - num_uncertain_points
+        idx = torch.topk(
+            point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
+        shift = num_sampled * torch.arange(
+            batch_size, dtype=torch.long, device=seg_logits.device)
+        idx += shift[:, None]
+        point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
+            batch_size, num_uncertain_points, 2)
+        if num_random_points > 0:
+            rand_point_coords = torch.rand(
+                batch_size, num_random_points, 2, device=seg_logits.device)
+            point_coords = torch.cat((point_coords, rand_point_coords), dim=1)
+        return point_coords
+
+    def get_points_test(self, seg_logits, uncertainty_func, cfg):
+        """Sample points for testing.
+
+        Find ``num_points`` most uncertain points from ``uncertainty_map``.
+
+        Args:
+            seg_logits (Tensor): A tensor of shape (batch_size, num_classes,
+                height, width) for class-specific or class-agnostic prediction.
+            uncertainty_func (func): uncertainty calculation function.
+            cfg (dict): Testing config of point head.
+
+        Returns:
+            point_indices (Tensor): A tensor of shape (batch_size, num_points)
+                that contains indices from [0, height x width) of the most
+                uncertain points.
+            point_coords (Tensor): A tensor of shape (batch_size, num_points,
+                2) that contains [0, 1] x [0, 1] normalized coordinates of the
+                most uncertain points from the ``height x width`` grid.
+        """
+
+        num_points = cfg.subdivision_num_points
+        uncertainty_map = uncertainty_func(seg_logits)
+        batch_size, _, height, width = uncertainty_map.shape
+        h_step = 1.0 / height
+        w_step = 1.0 / width
+
+        uncertainty_map = uncertainty_map.view(batch_size, height * width)
+        num_points = min(height * width, num_points)
+        point_indices = uncertainty_map.topk(num_points, dim=1)[1]
+        point_coords = torch.zeros(
+            batch_size,
+            num_points,
+            2,
+            dtype=torch.float,
+            device=seg_logits.device)
+        point_coords[:, :, 0] = w_step / 2.0 + (point_indices %
+                                                width).float() * w_step
+        point_coords[:, :, 1] = h_step / 2.0 + (point_indices //
+                                                width).float() * h_step
+        return point_indices, point_coords
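
PointHead's forward operates on per-point features rather than dense maps; the full PointRend train/test loops additionally need the point-sampling configs used by get_points_train/get_points_test. A shape check of just the shared MLP, with assumed channel counts:

```python
import torch
from annotator.uniformer.mmseg.models.decode_heads.point_head import PointHead

head = PointHead(in_channels=[256], in_index=[0], channels=256,
                 num_classes=19, num_fcs=3)
fine = torch.randn(2, 256, 1024)    # sampled fine-grained feats: (N, sum(in_channels), num_points)
coarse = torch.randn(2, 19, 1024)   # sampled coarse logits: (N, num_classes, num_points)
logits = head(fine, coarse)
print(logits.shape)  # torch.Size([2, 19, 1024])
```
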
diff --git a/annotator/uniformer/mmseg/models/decode_heads/psa_head.py b/annotator/uniformer/mmseg/models/decode_heads/psa_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..480dbd1a081262e45bf87e32c4a339ac8f8b4ffb
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/psa_head.py
@@ -0,0 +1,196 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+try:
+    from annotator.uniformer.mmcv.ops import PSAMask
+except ModuleNotFoundError:
+    PSAMask = None
+
+
+@HEADS.register_module()
+class PSAHead(BaseDecodeHead):
+    """Point-wise Spatial Attention Network for Scene Parsing.
+
+    This head is the implementation of `PSANet
+    <https://hszhao.github.io/papers/eccv18_psanet.pdf>`_.
+
+    Args:
+        mask_size (tuple[int]): The PSA mask size. It usually equals input
+            size.
+        psa_type (str): The type of psa module. Options are 'collect',
+            'distribute', 'bi-direction'. Default: 'bi-direction'
+        compact (bool): Whether to use a compact map for 'collect' mode.
+            Default: False.
+        shrink_factor (int): The downsample factor of the psa mask. Default: 2.
+        normalization_factor (float): The normalization factor of attention.
+        psa_softmax (bool): Whether to apply softmax to the attention.
+    """
+
+    def __init__(self,
+                 mask_size,
+                 psa_type='bi-direction',
+                 compact=False,
+                 shrink_factor=2,
+                 normalization_factor=1.0,
+                 psa_softmax=True,
+                 **kwargs):
+        if PSAMask is None:
+            raise RuntimeError('Please install mmcv-full for PSAMask ops')
+        super(PSAHead, self).__init__(**kwargs)
+        assert psa_type in ['collect', 'distribute', 'bi-direction']
+        self.psa_type = psa_type
+        self.compact = compact
+        self.shrink_factor = shrink_factor
+        self.mask_size = mask_size
+        mask_h, mask_w = mask_size
+        self.psa_softmax = psa_softmax
+        if normalization_factor is None:
+            normalization_factor = mask_h * mask_w
+        self.normalization_factor = normalization_factor
+
+        self.reduce = ConvModule(
+            self.in_channels,
+            self.channels,
+            kernel_size=1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        self.attention = nn.Sequential(
+            ConvModule(
+                self.channels,
+                self.channels,
+                kernel_size=1,
+                conv_cfg=self.conv_cfg,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg),
+            nn.Conv2d(
+                self.channels, mask_h * mask_w, kernel_size=1, bias=False))
+        if psa_type == 'bi-direction':
+            self.reduce_p = ConvModule(
+                self.in_channels,
+                self.channels,
+                kernel_size=1,
+                conv_cfg=self.conv_cfg,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg)
+            self.attention_p = nn.Sequential(
+                ConvModule(
+                    self.channels,
+                    self.channels,
+                    kernel_size=1,
+                    conv_cfg=self.conv_cfg,
+                    norm_cfg=self.norm_cfg,
+                    act_cfg=self.act_cfg),
+                nn.Conv2d(
+                    self.channels, mask_h * mask_w, kernel_size=1, bias=False))
+            self.psamask_collect = PSAMask('collect', mask_size)
+            self.psamask_distribute = PSAMask('distribute', mask_size)
+        else:
+            self.psamask = PSAMask(psa_type, mask_size)
+        self.proj = ConvModule(
+            self.channels * (2 if psa_type == 'bi-direction' else 1),
+            self.in_channels,
+            kernel_size=1,
+            padding=1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        self.bottleneck = ConvModule(
+            self.in_channels * 2,
+            self.channels,
+            kernel_size=3,
+            padding=1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+
+    def forward(self, inputs):
+        """Forward function."""
+        x = self._transform_inputs(inputs)
+        identity = x
+        align_corners = self.align_corners
+        if self.psa_type in ['collect', 'distribute']:
+            out = self.reduce(x)
+            n, c, h, w = out.size()
+            if self.shrink_factor != 1:
+                if h % self.shrink_factor and w % self.shrink_factor:
+                    h = (h - 1) // self.shrink_factor + 1
+                    w = (w - 1) // self.shrink_factor + 1
+                    align_corners = True
+                else:
+                    h = h // self.shrink_factor
+                    w = w // self.shrink_factor
+                    align_corners = False
+                out = resize(
+                    out,
+                    size=(h, w),
+                    mode='bilinear',
+                    align_corners=align_corners)
+            y = self.attention(out)
+            if self.compact:
+                if self.psa_type == 'collect':
+                    y = y.view(n, h * w,
+                               h * w).transpose(1, 2).view(n, h * w, h, w)
+            else:
+                y = self.psamask(y)
+            if self.psa_softmax:
+                y = F.softmax(y, dim=1)
+            out = torch.bmm(
+                out.view(n, c, h * w), y.view(n, h * w, h * w)).view(
+                    n, c, h, w) * (1.0 / self.normalization_factor)
+        else:
+            x_col = self.reduce(x)
+            x_dis = self.reduce_p(x)
+            n, c, h, w = x_col.size()
+            if self.shrink_factor != 1:
+                if h % self.shrink_factor and w % self.shrink_factor:
+                    h = (h - 1) // self.shrink_factor + 1
+                    w = (w - 1) // self.shrink_factor + 1
+                    align_corners = True
+                else:
+                    h = h // self.shrink_factor
+                    w = w // self.shrink_factor
+                    align_corners = False
+                x_col = resize(
+                    x_col,
+                    size=(h, w),
+                    mode='bilinear',
+                    align_corners=align_corners)
+                x_dis = resize(
+                    x_dis,
+                    size=(h, w),
+                    mode='bilinear',
+                    align_corners=align_corners)
+            y_col = self.attention(x_col)
+            y_dis = self.attention_p(x_dis)
+            if self.compact:
+                y_dis = y_dis.view(n, h * w,
+                                   h * w).transpose(1, 2).view(n, h * w, h, w)
+            else:
+                y_col = self.psamask_collect(y_col)
+                y_dis = self.psamask_distribute(y_dis)
+            if self.psa_softmax:
+                y_col = F.softmax(y_col, dim=1)
+                y_dis = F.softmax(y_dis, dim=1)
+            x_col = torch.bmm(
+                x_col.view(n, c, h * w), y_col.view(n, h * w, h * w)).view(
+                    n, c, h, w) * (1.0 / self.normalization_factor)
+            x_dis = torch.bmm(
+                x_dis.view(n, c, h * w), y_dis.view(n, h * w, h * w)).view(
+                    n, c, h, w) * (1.0 / self.normalization_factor)
+            out = torch.cat([x_col, x_dis], 1)
+        out = self.proj(out)
+        out = resize(
+            out,
+            size=identity.shape[2:],
+            mode='bilinear',
+            align_corners=align_corners)
+        out = self.bottleneck(torch.cat((identity, out), dim=1))
+        out = self.cls_seg(out)
+        return out
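The collect/distribute branches above both reduce the input, predict an (h*w)-channel attention mask per spatial position, optionally soft-max it, and aggregate features with a batched matrix product. A minimal shape sketch of that aggregation step using plain torch tensors in place of the head; the sizes below are illustrative assumptions, not values from any config:

import torch
import torch.nn.functional as F

n, c, h, w = 2, 64, 24, 24
feat = torch.randn(n, c, h, w)              # reduced feature map
mask = torch.randn(n, h * w, h, w)          # one (h*w)-dim mask per position
mask = F.softmax(mask, dim=1)               # normalize over the mask channels
agg = torch.bmm(feat.view(n, c, h * w),     # (n, c, h*w) x (n, h*w, h*w)
                mask.view(n, h * w, h * w)).view(n, c, h, w)
assert agg.shape == (n, c, h, w)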
diff --git a/annotator/uniformer/mmseg/models/decode_heads/psp_head.py b/annotator/uniformer/mmseg/models/decode_heads/psp_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5f1e71c70c3a20f4007c263ec471a87bb214a48
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/psp_head.py
@@ -0,0 +1,101 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+class PPM(nn.ModuleList):
+    """Pooling Pyramid Module used in PSPNet.
+
+    Args:
+        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+            Module.
+        in_channels (int): Input channels.
+        channels (int): Channels after modules, before conv_seg.
+        conv_cfg (dict|None): Config of conv layers.
+        norm_cfg (dict|None): Config of norm layers.
+        act_cfg (dict): Config of activation layers.
+        align_corners (bool): align_corners argument of F.interpolate.
+    """
+
+    def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
+                 act_cfg, align_corners):
+        super(PPM, self).__init__()
+        self.pool_scales = pool_scales
+        self.align_corners = align_corners
+        self.in_channels = in_channels
+        self.channels = channels
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        for pool_scale in pool_scales:
+            self.append(
+                nn.Sequential(
+                    nn.AdaptiveAvgPool2d(pool_scale),
+                    ConvModule(
+                        self.in_channels,
+                        self.channels,
+                        1,
+                        conv_cfg=self.conv_cfg,
+                        norm_cfg=self.norm_cfg,
+                        act_cfg=self.act_cfg)))
+
+    def forward(self, x):
+        """Forward function."""
+        ppm_outs = []
+        for ppm in self:
+            ppm_out = ppm(x)
+            upsampled_ppm_out = resize(
+                ppm_out,
+                size=x.size()[2:],
+                mode='bilinear',
+                align_corners=self.align_corners)
+            ppm_outs.append(upsampled_ppm_out)
+        return ppm_outs
+
+
+@HEADS.register_module()
+class PSPHead(BaseDecodeHead):
+    """Pyramid Scene Parsing Network.
+
+    This head is the implementation of
+    `PSPNet <https://arxiv.org/abs/1612.01105>`_.
+
+    Args:
+        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+            Module. Default: (1, 2, 3, 6).
+    """
+
+    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
+        super(PSPHead, self).__init__(**kwargs)
+        assert isinstance(pool_scales, (list, tuple))
+        self.pool_scales = pool_scales
+        self.psp_modules = PPM(
+            self.pool_scales,
+            self.in_channels,
+            self.channels,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg,
+            align_corners=self.align_corners)
+        self.bottleneck = ConvModule(
+            self.in_channels + len(pool_scales) * self.channels,
+            self.channels,
+            3,
+            padding=1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+
+    def forward(self, inputs):
+        """Forward function."""
+        x = self._transform_inputs(inputs)
+        psp_outs = [x]
+        psp_outs.extend(self.psp_modules(x))
+        psp_outs = torch.cat(psp_outs, dim=1)
+        output = self.bottleneck(psp_outs)
+        output = self.cls_seg(output)
+        return output
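PPM pools the input to each scale with adaptive average pooling, projects it with a 1x1 ConvModule, and bilinearly upsamples the result back to the input resolution so PSPHead can concatenate every scale with the original feature. A small usage sketch of PPM on its own; the channel counts, spatial size, and act_cfg are assumptions, not values fixed by this file:

import torch

ppm = PPM(pool_scales=(1, 2, 3, 6), in_channels=2048, channels=512,
          conv_cfg=None, norm_cfg=None, act_cfg=dict(type='ReLU'),
          align_corners=False)
x = torch.randn(1, 2048, 32, 32)
outs = ppm(x)                                # one tensor per pool scale
assert all(o.shape == (1, 512, 32, 32) for o in outs)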
diff --git a/annotator/uniformer/mmseg/models/decode_heads/sep_aspp_head.py b/annotator/uniformer/mmseg/models/decode_heads/sep_aspp_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..3339a7ac56e77dfc638e9bffb557d4699148686b
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/sep_aspp_head.py
@@ -0,0 +1,101 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
+
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from .aspp_head import ASPPHead, ASPPModule
+
+
+class DepthwiseSeparableASPPModule(ASPPModule):
+    """Atrous Spatial Pyramid Pooling (ASPP) Module with depthwise separable
+    conv."""
+
+    def __init__(self, **kwargs):
+        super(DepthwiseSeparableASPPModule, self).__init__(**kwargs)
+        for i, dilation in enumerate(self.dilations):
+            if dilation > 1:
+                self[i] = DepthwiseSeparableConvModule(
+                    self.in_channels,
+                    self.channels,
+                    3,
+                    dilation=dilation,
+                    padding=dilation,
+                    norm_cfg=self.norm_cfg,
+                    act_cfg=self.act_cfg)
+
+
+@HEADS.register_module()
+class DepthwiseSeparableASPPHead(ASPPHead):
+    """Encoder-Decoder with Atrous Separable Convolution for Semantic Image
+    Segmentation.
+
+    This head is the implementation of `DeepLabV3+
+    <https://arxiv.org/abs/1802.02611>`_.
+
+    Args:
+        c1_in_channels (int): The input channels of the c1 decoder. If it is 0,
+            no c1 decoder will be used.
+        c1_channels (int): The intermediate channels of c1 decoder.
+    """
+
+    def __init__(self, c1_in_channels, c1_channels, **kwargs):
+        super(DepthwiseSeparableASPPHead, self).__init__(**kwargs)
+        assert c1_in_channels >= 0
+        self.aspp_modules = DepthwiseSeparableASPPModule(
+            dilations=self.dilations,
+            in_channels=self.in_channels,
+            channels=self.channels,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        if c1_in_channels > 0:
+            self.c1_bottleneck = ConvModule(
+                c1_in_channels,
+                c1_channels,
+                1,
+                conv_cfg=self.conv_cfg,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg)
+        else:
+            self.c1_bottleneck = None
+        self.sep_bottleneck = nn.Sequential(
+            DepthwiseSeparableConvModule(
+                self.channels + c1_channels,
+                self.channels,
+                3,
+                padding=1,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg),
+            DepthwiseSeparableConvModule(
+                self.channels,
+                self.channels,
+                3,
+                padding=1,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg))
+
+    def forward(self, inputs):
+        """Forward function."""
+        x = self._transform_inputs(inputs)
+        aspp_outs = [
+            resize(
+                self.image_pool(x),
+                size=x.size()[2:],
+                mode='bilinear',
+                align_corners=self.align_corners)
+        ]
+        aspp_outs.extend(self.aspp_modules(x))
+        aspp_outs = torch.cat(aspp_outs, dim=1)
+        output = self.bottleneck(aspp_outs)
+        if self.c1_bottleneck is not None:
+            c1_output = self.c1_bottleneck(inputs[0])
+            output = resize(
+                input=output,
+                size=c1_output.shape[2:],
+                mode='bilinear',
+                align_corners=self.align_corners)
+            output = torch.cat([output, c1_output], dim=1)
+        output = self.sep_bottleneck(output)
+        output = self.cls_seg(output)
+        return output
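DepthwiseSeparableASPPHead combines the ASPP branch on a high-level feature with a 1x1-projected low-level (c1) feature. A construction sketch of how it is typically instantiated for a ResNet-style backbone; every numeric value below is an assumption borrowed from common DeepLabV3+ settings, not something this file prescribes:

head = DepthwiseSeparableASPPHead(
    in_channels=2048, in_index=3,            # high-level backbone stage
    channels=512, dilations=(1, 12, 24, 36),
    c1_in_channels=256, c1_channels=48,      # low-level shortcut (inputs[0])
    num_classes=19, norm_cfg=dict(type='BN'), align_corners=False)
# forward(inputs) expects the full list of backbone feature maps: inputs[0]
# feeds the c1 bottleneck, inputs[in_index] feeds the ASPP branch.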
diff --git a/annotator/uniformer/mmseg/models/decode_heads/sep_fcn_head.py b/annotator/uniformer/mmseg/models/decode_heads/sep_fcn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0986143fa4f2bd36f5271354fe5f843f35b9e6f
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/sep_fcn_head.py
@@ -0,0 +1,51 @@
+from annotator.uniformer.mmcv.cnn import DepthwiseSeparableConvModule
+
+from ..builder import HEADS
+from .fcn_head import FCNHead
+
+
+@HEADS.register_module()
+class DepthwiseSeparableFCNHead(FCNHead):
+    """Depthwise-Separable Fully Convolutional Network for Semantic
+    Segmentation.
+
+    This head is implemented according to the Fast-SCNN paper.
+
+    Args:
+        in_channels(int): Number of output channels of FFM.
+        channels(int): Number of middle-stage channels in the decode head.
+        concat_input(bool): Whether to concatenate original decode input into
+            the result of several consecutive convolution layers.
+            Default: True.
+        num_classes(int): Used to determine the dimension of
+            final prediction tensor.
+        in_index(int): Corresponds to 'out_indices' in the FastSCNN backbone.
+        norm_cfg (dict | None): Config of norm layers.
+        align_corners (bool): align_corners argument of F.interpolate.
+            Default: False.
+        loss_decode(dict): Config of loss type and some
+            relevant additional options.
+    """
+
+    def __init__(self, **kwargs):
+        super(DepthwiseSeparableFCNHead, self).__init__(**kwargs)
+        self.convs[0] = DepthwiseSeparableConvModule(
+            self.in_channels,
+            self.channels,
+            kernel_size=self.kernel_size,
+            padding=self.kernel_size // 2,
+            norm_cfg=self.norm_cfg)
+        for i in range(1, self.num_convs):
+            self.convs[i] = DepthwiseSeparableConvModule(
+                self.channels,
+                self.channels,
+                kernel_size=self.kernel_size,
+                padding=self.kernel_size // 2,
+                norm_cfg=self.norm_cfg)
+
+        if self.concat_input:
+            self.conv_cat = DepthwiseSeparableConvModule(
+                self.in_channels + self.channels,
+                self.channels,
+                kernel_size=self.kernel_size,
+                padding=self.kernel_size // 2,
+                norm_cfg=self.norm_cfg)
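Only the convolution type changes here: the parent FCNHead builds self.convs and (optionally) self.conv_cat, and this subclass swaps them for depthwise-separable versions. A usage sketch with values in the spirit of a Fast-SCNN configuration; the numbers are assumptions, not requirements of this file:

head = DepthwiseSeparableFCNHead(
    in_channels=128, channels=128, concat_input=False,
    num_classes=19, in_index=-1,
    norm_cfg=dict(type='BN'), align_corners=False)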
diff --git a/annotator/uniformer/mmseg/models/decode_heads/uper_head.py b/annotator/uniformer/mmseg/models/decode_heads/uper_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e1301b706b0d83ed714bbdee8ee24693f150455
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/decode_heads/uper_head.py
@@ -0,0 +1,126 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from annotator.uniformer.mmseg.ops import resize
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+from .psp_head import PPM
+
+
+@HEADS.register_module()
+class UPerHead(BaseDecodeHead):
+    """Unified Perceptual Parsing for Scene Understanding.
+
+    This head is the implementation of `UPerNet
+    <https://arxiv.org/abs/1807.10221>`_.
+
+    Args:
+        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
+            Module applied on the last feature. Default: (1, 2, 3, 6).
+    """
+
+    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
+        super(UPerHead, self).__init__(
+            input_transform='multiple_select', **kwargs)
+        # PSP Module
+        self.psp_modules = PPM(
+            pool_scales,
+            self.in_channels[-1],
+            self.channels,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg,
+            align_corners=self.align_corners)
+        self.bottleneck = ConvModule(
+            self.in_channels[-1] + len(pool_scales) * self.channels,
+            self.channels,
+            3,
+            padding=1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+        # FPN Module
+        self.lateral_convs = nn.ModuleList()
+        self.fpn_convs = nn.ModuleList()
+        for in_channels in self.in_channels[:-1]:  # skip the top layer
+            l_conv = ConvModule(
+                in_channels,
+                self.channels,
+                1,
+                conv_cfg=self.conv_cfg,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg,
+                inplace=False)
+            fpn_conv = ConvModule(
+                self.channels,
+                self.channels,
+                3,
+                padding=1,
+                conv_cfg=self.conv_cfg,
+                norm_cfg=self.norm_cfg,
+                act_cfg=self.act_cfg,
+                inplace=False)
+            self.lateral_convs.append(l_conv)
+            self.fpn_convs.append(fpn_conv)
+
+        self.fpn_bottleneck = ConvModule(
+            len(self.in_channels) * self.channels,
+            self.channels,
+            3,
+            padding=1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+
+    def psp_forward(self, inputs):
+        """Forward function of PSP module."""
+        x = inputs[-1]
+        psp_outs = [x]
+        psp_outs.extend(self.psp_modules(x))
+        psp_outs = torch.cat(psp_outs, dim=1)
+        output = self.bottleneck(psp_outs)
+
+        return output
+
+    def forward(self, inputs):
+        """Forward function."""
+
+        inputs = self._transform_inputs(inputs)
+
+        # build laterals
+        laterals = [
+            lateral_conv(inputs[i])
+            for i, lateral_conv in enumerate(self.lateral_convs)
+        ]
+
+        laterals.append(self.psp_forward(inputs))
+
+        # build top-down path
+        used_backbone_levels = len(laterals)
+        for i in range(used_backbone_levels - 1, 0, -1):
+            prev_shape = laterals[i - 1].shape[2:]
+            laterals[i - 1] += resize(
+                laterals[i],
+                size=prev_shape,
+                mode='bilinear',
+                align_corners=self.align_corners)
+
+        # build outputs
+        fpn_outs = [
+            self.fpn_convs[i](laterals[i])
+            for i in range(used_backbone_levels - 1)
+        ]
+        # append psp feature
+        fpn_outs.append(laterals[-1])
+
+        for i in range(used_backbone_levels - 1, 0, -1):
+            fpn_outs[i] = resize(
+                fpn_outs[i],
+                size=fpn_outs[0].shape[2:],
+                mode='bilinear',
+                align_corners=self.align_corners)
+        fpn_outs = torch.cat(fpn_outs, dim=1)
+        output = self.fpn_bottleneck(fpn_outs)
+        output = self.cls_seg(output)
+        return output
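UPerHead expects one input per selected backbone stage (input_transform='multiple_select'), runs the PPM on the last one, and fuses all levels through an FPN-style top-down path. A construction sketch; the per-stage channel list is an assumption for a hierarchical (e.g. UniFormer/Swin-like) backbone, not a value taken from this file:

head = UPerHead(
    in_channels=[96, 192, 384, 768],         # one entry per backbone stage
    in_index=[0, 1, 2, 3],
    channels=512, pool_scales=(1, 2, 3, 6),
    num_classes=150, norm_cfg=dict(type='BN'), align_corners=False)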
diff --git a/annotator/uniformer/mmseg/models/losses/__init__.py b/annotator/uniformer/mmseg/models/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..beca72045694273d63465bac2f27dbc6672271db
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/losses/__init__.py
@@ -0,0 +1,12 @@
+from .accuracy import Accuracy, accuracy
+from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
+                                 cross_entropy, mask_cross_entropy)
+from .dice_loss import DiceLoss
+from .lovasz_loss import LovaszLoss
+from .utils import reduce_loss, weight_reduce_loss, weighted_loss
+
+__all__ = [
+    'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
+    'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
+    'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss'
+]
diff --git a/annotator/uniformer/mmseg/models/losses/accuracy.py b/annotator/uniformer/mmseg/models/losses/accuracy.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0fd2e7e74a0f721c4a814c09d6e453e5956bb38
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/losses/accuracy.py
@@ -0,0 +1,78 @@
+import torch.nn as nn
+
+
+def accuracy(pred, target, topk=1, thresh=None):
+    """Calculate accuracy according to the prediction and target.
+
+    Args:
+        pred (torch.Tensor): The model prediction, shape (N, num_class, ...)
+        target (torch.Tensor): The target of each prediction, shape (N, ...).
+        topk (int | tuple[int], optional): If the predictions in ``topk``
+            match the target, the predictions will be regarded as
+            correct ones. Defaults to 1.
+        thresh (float, optional): If not None, predictions with scores under
+            this threshold are considered incorrect. Defaults to None.
+
+    Returns:
+        float | tuple[float]: If the input ``topk`` is a single integer,
+            the function will return a single float as accuracy. If
+            ``topk`` is a tuple containing multiple integers, the
+            function will return a tuple containing accuracies of
+            each ``topk`` number.
+    """
+    assert isinstance(topk, (int, tuple))
+    if isinstance(topk, int):
+        topk = (topk, )
+        return_single = True
+    else:
+        return_single = False
+
+    maxk = max(topk)
+    if pred.size(0) == 0:
+        accu = [pred.new_tensor(0.) for i in range(len(topk))]
+        return accu[0] if return_single else accu
+    assert pred.ndim == target.ndim + 1
+    assert pred.size(0) == target.size(0)
+    assert maxk <= pred.size(1), \
+        f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
+    pred_value, pred_label = pred.topk(maxk, dim=1)
+    # transpose to shape (maxk, N, ...)
+    pred_label = pred_label.transpose(0, 1)
+    correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label))
+    if thresh is not None:
+        # Only prediction values larger than thresh are counted as correct
+        correct = correct & (pred_value > thresh).t()
+    res = []
+    for k in topk:
+        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+        res.append(correct_k.mul_(100.0 / target.numel()))
+    return res[0] if return_single else res
+
+
+class Accuracy(nn.Module):
+    """Accuracy calculation module."""
+
+    def __init__(self, topk=(1, ), thresh=None):
+        """Module to calculate the accuracy.
+
+        Args:
+            topk (tuple, optional): The criterion used to calculate the
+                accuracy. Defaults to (1,).
+            thresh (float, optional): If not None, predictions with scores
+                under this threshold are considered incorrect. Defaults to None.
+        """
+        super().__init__()
+        self.topk = topk
+        self.thresh = thresh
+
+    def forward(self, pred, target):
+        """Forward function to calculate accuracy.
+
+        Args:
+            pred (torch.Tensor): Prediction of models.
+            target (torch.Tensor): Target for each prediction.
+
+        Returns:
+            tuple[float]: The accuracies under different topk criterions.
+        """
+        return accuracy(pred, target, self.topk, self.thresh)
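A small worked example of the function above, using hand-picked toy tensors rather than real model outputs: the top-1 predictions are [1, 0, 1] against targets [1, 0, 0], so two of three are correct.

import torch

pred = torch.tensor([[0.1, 0.9],
                     [0.8, 0.2],
                     [0.3, 0.7]])            # (N=3, num_class=2) scores
target = torch.tensor([1, 0, 0])
accuracy(pred, target, topk=1)               # tensor([66.6667])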
diff --git a/annotator/uniformer/mmseg/models/losses/cross_entropy_loss.py b/annotator/uniformer/mmseg/models/losses/cross_entropy_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..42c0790c98616bb69621deed55547fc04c7392ef
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/losses/cross_entropy_loss.py
@@ -0,0 +1,198 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import get_class_weight, weight_reduce_loss
+
+
+def cross_entropy(pred,
+                  label,
+                  weight=None,
+                  class_weight=None,
+                  reduction='mean',
+                  avg_factor=None,
+                  ignore_index=-100):
+    """The wrapper function for :func:`F.cross_entropy`"""
+    # class_weight is a manual rescaling weight given to each class.
+    # If given, it has to be a Tensor of size C.
+    loss = F.cross_entropy(
+        pred,
+        label,
+        weight=class_weight,
+        reduction='none',
+        ignore_index=ignore_index)
+
+    # apply weights and do the reduction
+    if weight is not None:
+        weight = weight.float()
+    loss = weight_reduce_loss(
+        loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
+
+    return loss
+
+
+def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
+    """Expand onehot labels to match the size of prediction."""
+    bin_labels = labels.new_zeros(target_shape)
+    valid_mask = (labels >= 0) & (labels != ignore_index)
+    inds = torch.nonzero(valid_mask, as_tuple=True)
+
+    if inds[0].numel() > 0:
+        if labels.dim() == 3:
+            bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
+        else:
+            bin_labels[inds[0], labels[valid_mask]] = 1
+
+    valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
+    if label_weights is None:
+        bin_label_weights = valid_mask
+    else:
+        bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
+        bin_label_weights *= valid_mask
+
+    return bin_labels, bin_label_weights
+
+
+def binary_cross_entropy(pred,
+                         label,
+                         weight=None,
+                         reduction='mean',
+                         avg_factor=None,
+                         class_weight=None,
+                         ignore_index=255):
+    """Calculate the binary CrossEntropy loss.
+
+    Args:
+        pred (torch.Tensor): The prediction with shape (N, 1).
+        label (torch.Tensor): The learning label of the prediction.
+        weight (torch.Tensor, optional): Sample-wise loss weight.
+        reduction (str, optional): The method used to reduce the loss.
+            Options are "none", "mean" and "sum".
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. Defaults to None.
+        class_weight (list[float], optional): The weight for each class.
+        ignore_index (int | None): The label index to be ignored. Default: 255
+
+    Returns:
+        torch.Tensor: The calculated loss
+    """
+    if pred.dim() != label.dim():
+        assert (pred.dim() == 2 and label.dim() == 1) or (
+                pred.dim() == 4 and label.dim() == 3), \
+            'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \
+            'H, W], label shape [N, H, W] are supported'
+        label, weight = _expand_onehot_labels(label, weight, pred.shape,
+                                              ignore_index)
+
+    # weighted element-wise losses
+    if weight is not None:
+        weight = weight.float()
+    loss = F.binary_cross_entropy_with_logits(
+        pred, label.float(), pos_weight=class_weight, reduction='none')
+    # do the reduction for the weighted loss
+    loss = weight_reduce_loss(
+        loss, weight, reduction=reduction, avg_factor=avg_factor)
+
+    return loss
+
+
+def mask_cross_entropy(pred,
+                       target,
+                       label,
+                       reduction='mean',
+                       avg_factor=None,
+                       class_weight=None,
+                       ignore_index=None):
+    """Calculate the CrossEntropy loss for masks.
+
+    Args:
+        pred (torch.Tensor): The prediction with shape (N, C), C is the number
+            of classes.
+        target (torch.Tensor): The learning label of the prediction.
+        label (torch.Tensor): ``label`` indicates the class label of the mask's
+            corresponding object. This will be used to select the mask of the
+            class which the object belongs to when the mask prediction is not
+            class-agnostic.
+        reduction (str, optional): The method used to reduce the loss.
+            Options are "none", "mean" and "sum".
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. Defaults to None.
+        class_weight (list[float], optional): The weight for each class.
+        ignore_index (None): Placeholder, to be consistent with other loss.
+            Default: None.
+
+    Returns:
+        torch.Tensor: The calculated loss
+    """
+    assert ignore_index is None, 'BCE loss does not support ignore_index'
+    # TODO: handle these two reserved arguments
+    assert reduction == 'mean' and avg_factor is None
+    num_rois = pred.size()[0]
+    inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
+    pred_slice = pred[inds, label].squeeze(1)
+    return F.binary_cross_entropy_with_logits(
+        pred_slice, target, weight=class_weight, reduction='mean')[None]
+
+
+@LOSSES.register_module()
+class CrossEntropyLoss(nn.Module):
+    """CrossEntropyLoss.
+
+    Args:
+        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+            instead of softmax. Defaults to False.
+        use_mask (bool, optional): Whether to use mask cross entropy loss.
+            Defaults to False.
+        reduction (str, optional): The method used to reduce the loss.
+            Options are "none", "mean" and "sum". Defaults to 'mean'.
+        class_weight (list[float] | str, optional): Weight of each class. If in
+            str format, read them from a file. Defaults to None.
+        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+    """
+
+    def __init__(self,
+                 use_sigmoid=False,
+                 use_mask=False,
+                 reduction='mean',
+                 class_weight=None,
+                 loss_weight=1.0):
+        super(CrossEntropyLoss, self).__init__()
+        assert (use_sigmoid is False) or (use_mask is False)
+        self.use_sigmoid = use_sigmoid
+        self.use_mask = use_mask
+        self.reduction = reduction
+        self.loss_weight = loss_weight
+        self.class_weight = get_class_weight(class_weight)
+
+        if self.use_sigmoid:
+            self.cls_criterion = binary_cross_entropy
+        elif self.use_mask:
+            self.cls_criterion = mask_cross_entropy
+        else:
+            self.cls_criterion = cross_entropy
+
+    def forward(self,
+                cls_score,
+                label,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
+                **kwargs):
+        """Forward function."""
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
+        if self.class_weight is not None:
+            class_weight = cls_score.new_tensor(self.class_weight)
+        else:
+            class_weight = None
+        loss_cls = self.loss_weight * self.cls_criterion(
+            cls_score,
+            label,
+            weight,
+            class_weight=class_weight,
+            reduction=reduction,
+            avg_factor=avg_factor,
+            **kwargs)
+        return loss_cls
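A usage sketch of the module in its default (softmax cross-entropy) mode; the shapes and class count are illustrative assumptions:

import torch

loss_fn = CrossEntropyLoss(use_sigmoid=False, loss_weight=1.0)
logits = torch.randn(2, 19, 64, 64)          # (N, C, H, W) raw scores
labels = torch.randint(0, 19, (2, 64, 64))   # (N, H, W) class indices
loss = loss_fn(logits, labels)               # scalar, mean-reduced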
diff --git a/annotator/uniformer/mmseg/models/losses/dice_loss.py b/annotator/uniformer/mmseg/models/losses/dice_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..27a77b962d7d8b3079c7d6cd9db52280c6fb4970
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/losses/dice_loss.py
@@ -0,0 +1,119 @@
+"""Modified from https://github.com/LikeLy-Journey/SegmenTron/blob/master/
+segmentron/solver/loss.py (Apache-2.0 License)"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import get_class_weight, weighted_loss
+
+
+@weighted_loss
+def dice_loss(pred,
+              target,
+              valid_mask,
+              smooth=1,
+              exponent=2,
+              class_weight=None,
+              ignore_index=255):
+    assert pred.shape[0] == target.shape[0]
+    total_loss = 0
+    num_classes = pred.shape[1]
+    for i in range(num_classes):
+        if i != ignore_index:
+            dice_loss = binary_dice_loss(
+                pred[:, i],
+                target[..., i],
+                valid_mask=valid_mask,
+                smooth=smooth,
+                exponent=exponent)
+            if class_weight is not None:
+                dice_loss *= class_weight[i]
+            total_loss += dice_loss
+    return total_loss / num_classes
+
+
+@weighted_loss
+def binary_dice_loss(pred, target, valid_mask, smooth=1, exponent=2, **kwargs):
+    assert pred.shape[0] == target.shape[0]
+    pred = pred.reshape(pred.shape[0], -1)
+    target = target.reshape(target.shape[0], -1)
+    valid_mask = valid_mask.reshape(valid_mask.shape[0], -1)
+
+    num = torch.sum(torch.mul(pred, target) * valid_mask, dim=1) * 2 + smooth
+    den = torch.sum(pred.pow(exponent) + target.pow(exponent), dim=1) + smooth
+
+    return 1 - num / den
+
+
+@LOSSES.register_module()
+class DiceLoss(nn.Module):
+    """DiceLoss.
+
+    This loss is proposed in `V-Net: Fully Convolutional Neural Networks for
+    Volumetric Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.
+
+    Args:
+        smooth (float): A float number to smooth loss, and avoid NaN error.
+            Default: 1.
+        exponent (float): A float number to calculate denominator
+            value: \\sum{x^exponent} + \\sum{y^exponent}. Default: 2.
+        reduction (str, optional): The method used to reduce the loss. Options
+            are "none", "mean" and "sum". Default: 'mean'.
+        class_weight (list[float] | str, optional): Weight of each class. If in
+            str format, read them from a file. Defaults to None.
+        loss_weight (float, optional): Weight of the loss. Default to 1.0.
+        ignore_index (int | None): The label index to be ignored. Default: 255.
+    """
+
+    def __init__(self,
+                 smooth=1,
+                 exponent=2,
+                 reduction='mean',
+                 class_weight=None,
+                 loss_weight=1.0,
+                 ignore_index=255,
+                 **kwargs):
+        super(DiceLoss, self).__init__()
+        self.smooth = smooth
+        self.exponent = exponent
+        self.reduction = reduction
+        self.class_weight = get_class_weight(class_weight)
+        self.loss_weight = loss_weight
+        self.ignore_index = ignore_index
+
+    def forward(self,
+                pred,
+                target,
+                avg_factor=None,
+                reduction_override=None,
+                **kwargs):
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
+        if self.class_weight is not None:
+            class_weight = pred.new_tensor(self.class_weight)
+        else:
+            class_weight = None
+
+        pred = F.softmax(pred, dim=1)
+        num_classes = pred.shape[1]
+        one_hot_target = F.one_hot(
+            torch.clamp(target.long(), 0, num_classes - 1),
+            num_classes=num_classes)
+        valid_mask = (target != self.ignore_index).long()
+
+        loss = self.loss_weight * dice_loss(
+            pred,
+            one_hot_target,
+            valid_mask=valid_mask,
+            reduction=reduction,
+            avg_factor=avg_factor,
+            smooth=self.smooth,
+            exponent=self.exponent,
+            class_weight=class_weight,
+            ignore_index=self.ignore_index)
+        return loss
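The forward pass applies softmax over the class dimension, one-hot encodes the target, and averages the per-class binary dice terms. A usage sketch with assumed shapes and class count:

import torch

loss_fn = DiceLoss(smooth=1, exponent=2)
logits = torch.randn(2, 4, 32, 32)           # (N, C, H, W) raw scores
labels = torch.randint(0, 4, (2, 32, 32))    # (N, H, W) class indices
loss = loss_fn(logits, labels)               # scalar dice loss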
diff --git a/annotator/uniformer/mmseg/models/losses/lovasz_loss.py b/annotator/uniformer/mmseg/models/losses/lovasz_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..6badb67f6d987b59fb07aa97caaaf89896e27a8d
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/losses/lovasz_loss.py
@@ -0,0 +1,303 @@
+"""Modified from https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytor
+ch/lovasz_losses.py Lovasz-Softmax and Jaccard hinge loss in PyTorch Maxim
+Berman 2018 ESAT-PSI KU Leuven (MIT License)"""
+
+import annotator.uniformer.mmcv as mmcv
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import get_class_weight, weight_reduce_loss
+
+
+def lovasz_grad(gt_sorted):
+    """Computes gradient of the Lovasz extension w.r.t sorted errors.
+
+    See Alg. 1 in paper.
+    """
+    p = len(gt_sorted)
+    gts = gt_sorted.sum()
+    intersection = gts - gt_sorted.float().cumsum(0)
+    union = gts + (1 - gt_sorted).float().cumsum(0)
+    jaccard = 1. - intersection / union
+    if p > 1:  # cover 1-pixel case
+        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
+    return jaccard
+
+
+def flatten_binary_logits(logits, labels, ignore_index=None):
+    """Flattens predictions in the batch (binary case) Remove labels equal to
+    'ignore_index'."""
+    logits = logits.view(-1)
+    labels = labels.view(-1)
+    if ignore_index is None:
+        return logits, labels
+    valid = (labels != ignore_index)
+    vlogits = logits[valid]
+    vlabels = labels[valid]
+    return vlogits, vlabels
+
+
+def flatten_probs(probs, labels, ignore_index=None):
+    """Flattens predictions in the batch."""
+    if probs.dim() == 3:
+        # assumes output of a sigmoid layer
+        B, H, W = probs.size()
+        probs = probs.view(B, 1, H, W)
+    B, C, H, W = probs.size()
+    probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B*H*W, C=P,C
+    labels = labels.view(-1)
+    if ignore_index is None:
+        return probs, labels
+    valid = (labels != ignore_index)
+    vprobs = probs[valid.nonzero().squeeze()]
+    vlabels = labels[valid]
+    return vprobs, vlabels
+
+
+def lovasz_hinge_flat(logits, labels):
+    """Binary Lovasz hinge loss.
+
+    Args:
+        logits (torch.Tensor): [P], logits at each prediction
+            (between -infty and +infty).
+        labels (torch.Tensor): [P], binary ground truth labels (0 or 1).
+
+    Returns:
+        torch.Tensor: The calculated loss.
+    """
+    if len(labels) == 0:
+        # only void pixels, the gradients should be 0
+        return logits.sum() * 0.
+    signs = 2. * labels.float() - 1.
+    errors = (1. - logits * signs)
+    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
+    perm = perm.data
+    gt_sorted = labels[perm]
+    grad = lovasz_grad(gt_sorted)
+    loss = torch.dot(F.relu(errors_sorted), grad)
+    return loss
+
+
+def lovasz_hinge(logits,
+                 labels,
+                 classes='present',
+                 per_image=False,
+                 class_weight=None,
+                 reduction='mean',
+                 avg_factor=None,
+                 ignore_index=255):
+    """Binary Lovasz hinge loss.
+
+    Args:
+        logits (torch.Tensor): [B, H, W], logits at each pixel
+            (between -infty and +infty).
+        labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1).
+        classes (str | list[int], optional): Placeholder, to be consistent with
+            other loss. Default: None.
+        per_image (bool, optional): If per_image is True, compute the loss per
+            image instead of per batch. Default: False.
+        class_weight (list[float], optional): Placeholder, to be consistent
+            with other loss. Default: None.
+        reduction (str, optional): The method used to reduce the loss. Options
+            are "none", "mean" and "sum". This parameter only works when
+            per_image is True. Default: 'mean'.
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. This parameter only works when per_image is True.
+            Default: None.
+        ignore_index (int | None): The label index to be ignored. Default: 255.
+
+    Returns:
+        torch.Tensor: The calculated loss.
+    """
+    if per_image:
+        loss = [
+            lovasz_hinge_flat(*flatten_binary_logits(
+                logit.unsqueeze(0), label.unsqueeze(0), ignore_index))
+            for logit, label in zip(logits, labels)
+        ]
+        loss = weight_reduce_loss(
+            torch.stack(loss), None, reduction, avg_factor)
+    else:
+        loss = lovasz_hinge_flat(
+            *flatten_binary_logits(logits, labels, ignore_index))
+    return loss
+
+
+def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None):
+    """Multi-class Lovasz-Softmax loss.
+
+    Args:
+        probs (torch.Tensor): [P, C], class probabilities at each prediction
+            (between 0 and 1).
+        labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1).
+        classes (str | list[int], optional): Classes chosen to calculate loss.
+            'all' for all classes, 'present' for classes present in labels, or
+            a list of classes to average. Default: 'present'.
+        class_weight (list[float], optional): The weight for each class.
+            Default: None.
+
+    Returns:
+        torch.Tensor: The calculated loss.
+    """
+    if probs.numel() == 0:
+        # only void pixels, the gradients should be 0
+        return probs * 0.
+    C = probs.size(1)
+    losses = []
+    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
+    for c in class_to_sum:
+        fg = (labels == c).float()  # foreground for class c
+        if (classes == 'present' and fg.sum() == 0):
+            continue
+        if C == 1:
+            if len(classes) > 1:
+                raise ValueError('Sigmoid output possible only with 1 class')
+            class_pred = probs[:, 0]
+        else:
+            class_pred = probs[:, c]
+        errors = (fg - class_pred).abs()
+        errors_sorted, perm = torch.sort(errors, 0, descending=True)
+        perm = perm.data
+        fg_sorted = fg[perm]
+        loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted))
+        if class_weight is not None:
+            loss *= class_weight[c]
+        losses.append(loss)
+    return torch.stack(losses).mean()
+
+
+def lovasz_softmax(probs,
+                   labels,
+                   classes='present',
+                   per_image=False,
+                   class_weight=None,
+                   reduction='mean',
+                   avg_factor=None,
+                   ignore_index=255):
+    """Multi-class Lovasz-Softmax loss.
+
+    Args:
+        probs (torch.Tensor): [B, C, H, W], class probabilities at each
+            prediction (between 0 and 1).
+        labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and
+            C - 1).
+        classes (str | list[int], optional): Classes chosen to calculate loss.
+            'all' for all classes, 'present' for classes present in labels, or
+            a list of classes to average. Default: 'present'.
+        per_image (bool, optional): If per_image is True, compute the loss per
+            image instead of per batch. Default: False.
+        class_weight (list[float], optional): The weight for each class.
+            Default: None.
+        reduction (str, optional): The method used to reduce the loss. Options
+            are "none", "mean" and "sum". This parameter only works when
+            per_image is True. Default: 'mean'.
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. This parameter only works when per_image is True.
+            Default: None.
+        ignore_index (int | None): The label index to be ignored. Default: 255.
+
+    Returns:
+        torch.Tensor: The calculated loss.
+    """
+
+    if per_image:
+        loss = [
+            lovasz_softmax_flat(
+                *flatten_probs(
+                    prob.unsqueeze(0), label.unsqueeze(0), ignore_index),
+                classes=classes,
+                class_weight=class_weight)
+            for prob, label in zip(probs, labels)
+        ]
+        loss = weight_reduce_loss(
+            torch.stack(loss), None, reduction, avg_factor)
+    else:
+        loss = lovasz_softmax_flat(
+            *flatten_probs(probs, labels, ignore_index),
+            classes=classes,
+            class_weight=class_weight)
+    return loss
+
+
+@LOSSES.register_module()
+class LovaszLoss(nn.Module):
+    """LovaszLoss.
+
+    This loss is proposed in `The Lovasz-Softmax loss: A tractable surrogate
+    for the optimization of the intersection-over-union measure in neural
+    networks <https://arxiv.org/abs/1705.08790>`_.
+
+    Args:
+        loss_type (str, optional): Binary or multi-class loss.
+            Default: 'multi_class'. Options are "binary" and "multi_class".
+        classes (str | list[int], optional): Classes chosen to calculate loss.
+            'all' for all classes, 'present' for classes present in labels, or
+            a list of classes to average. Default: 'present'.
+        per_image (bool, optional): If per_image is True, compute the loss per
+            image instead of per batch. Default: False.
+        reduction (str, optional): The method used to reduce the loss. Options
+            are "none", "mean" and "sum". This parameter only works when
+            per_image is True. Default: 'mean'.
+        class_weight (list[float] | str, optional): Weight of each class. If in
+            str format, read them from a file. Defaults to None.
+        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+    """
+
+    def __init__(self,
+                 loss_type='multi_class',
+                 classes='present',
+                 per_image=False,
+                 reduction='mean',
+                 class_weight=None,
+                 loss_weight=1.0):
+        super(LovaszLoss, self).__init__()
+        assert loss_type in ('binary', 'multi_class'), "loss_type should be \
+                                                    'binary' or 'multi_class'."
+
+        if loss_type == 'binary':
+            self.cls_criterion = lovasz_hinge
+        else:
+            self.cls_criterion = lovasz_softmax
+        assert classes in ('all', 'present') or mmcv.is_list_of(classes, int)
+        if not per_image:
+            assert reduction == 'none', "reduction should be 'none' when \
+                                                        per_image is False."
+
+        self.classes = classes
+        self.per_image = per_image
+        self.reduction = reduction
+        self.loss_weight = loss_weight
+        self.class_weight = get_class_weight(class_weight)
+
+    def forward(self,
+                cls_score,
+                label,
+                weight=None,
+                avg_factor=None,
+                reduction_override=None,
+                **kwargs):
+        """Forward function."""
+        assert reduction_override in (None, 'none', 'mean', 'sum')
+        reduction = (
+            reduction_override if reduction_override else self.reduction)
+        if self.class_weight is not None:
+            class_weight = cls_score.new_tensor(self.class_weight)
+        else:
+            class_weight = None
+
+        # if multi-class loss, transform logits to probs
+        if self.cls_criterion == lovasz_softmax:
+            cls_score = F.softmax(cls_score, dim=1)
+
+        loss_cls = self.loss_weight * self.cls_criterion(
+            cls_score,
+            label,
+            self.classes,
+            self.per_image,
+            class_weight=class_weight,
+            reduction=reduction,
+            avg_factor=avg_factor,
+            **kwargs)
+        return loss_cls
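A usage sketch for the multi-class case; note that the constructor enforces reduction='none' whenever per_image is False, since the Lovasz extension is then computed over the whole flattened batch. Shapes and class count are assumptions:

import torch

loss_fn = LovaszLoss(loss_type='multi_class', per_image=False, reduction='none')
logits = torch.randn(2, 4, 32, 32)           # (N, C, H, W) raw scores
labels = torch.randint(0, 4, (2, 32, 32))    # (N, H, W) class indices
loss = loss_fn(logits, labels)               # scalar tensor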
diff --git a/annotator/uniformer/mmseg/models/losses/utils.py b/annotator/uniformer/mmseg/models/losses/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..85aec9f3045240c3de96a928324ae8f5c3aebe8b
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/losses/utils.py
@@ -0,0 +1,121 @@
+import functools
+
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+import torch.nn.functional as F
+
+
+def get_class_weight(class_weight):
+    """Get class weight for loss function.
+
+    Args:
+        class_weight (list[float] | str | None): If class_weight is a str,
+            take it as a file name and read from it.
+    """
+    if isinstance(class_weight, str):
+        # take it as a file path
+        if class_weight.endswith('.npy'):
+            class_weight = np.load(class_weight)
+        else:
+            # pkl, json or yaml
+            class_weight = mmcv.load(class_weight)
+
+    return class_weight
+
+
+def reduce_loss(loss, reduction):
+    """Reduce loss as specified.
+
+    Args:
+        loss (Tensor): Elementwise loss tensor.
+        reduction (str): Options are "none", "mean" and "sum".
+
+    Return:
+        Tensor: Reduced loss tensor.
+    """
+    reduction_enum = F._Reduction.get_enum(reduction)
+    # none: 0, elementwise_mean:1, sum: 2
+    if reduction_enum == 0:
+        return loss
+    elif reduction_enum == 1:
+        return loss.mean()
+    elif reduction_enum == 2:
+        return loss.sum()
+
+
+def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
+    """Apply element-wise weight and reduce loss.
+
+    Args:
+        loss (Tensor): Element-wise loss.
+        weight (Tensor): Element-wise weights.
+        reduction (str): Same as built-in losses of PyTorch.
+        avg_factor (float): Average factor when computing the mean of losses.
+
+    Returns:
+        Tensor: Processed loss values.
+    """
+    # if weight is specified, apply element-wise weight
+    if weight is not None:
+        assert weight.dim() == loss.dim()
+        if weight.dim() > 1:
+            assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
+        loss = loss * weight
+
+    # if avg_factor is not specified, just reduce the loss
+    if avg_factor is None:
+        loss = reduce_loss(loss, reduction)
+    else:
+        # if reduction is mean, then average the loss by avg_factor
+        if reduction == 'mean':
+            loss = loss.sum() / avg_factor
+        # if reduction is 'none', then do nothing, otherwise raise an error
+        elif reduction != 'none':
+            raise ValueError('avg_factor can not be used with reduction="sum"')
+    return loss
+
+
+def weighted_loss(loss_func):
+    """Create a weighted version of a given loss function.
+
+    To use this decorator, the loss function must have the signature like
+    `loss_func(pred, target, **kwargs)`. The function only needs to compute
+    element-wise loss without any reduction. This decorator will add weight
+    and reduction arguments to the function. The decorated function will have
+    the signature like `loss_func(pred, target, weight=None, reduction='mean',
+    avg_factor=None, **kwargs)`.
+
+    :Example:
+
+    >>> import torch
+    >>> @weighted_loss
+    >>> def l1_loss(pred, target):
+    >>>     return (pred - target).abs()
+
+    >>> pred = torch.Tensor([0, 2, 3])
+    >>> target = torch.Tensor([1, 1, 1])
+    >>> weight = torch.Tensor([1, 0, 1])
+
+    >>> l1_loss(pred, target)
+    tensor(1.3333)
+    >>> l1_loss(pred, target, weight)
+    tensor(1.)
+    >>> l1_loss(pred, target, reduction='none')
+    tensor([1., 1., 2.])
+    >>> l1_loss(pred, target, weight, avg_factor=2)
+    tensor(1.5000)
+    """
+
+    @functools.wraps(loss_func)
+    def wrapper(pred,
+                target,
+                weight=None,
+                reduction='mean',
+                avg_factor=None,
+                **kwargs):
+        # get element-wise loss
+        loss = loss_func(pred, target, **kwargs)
+        loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+        return loss
+
+    return wrapper
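A short worked example of weight_reduce_loss with hand-picked numbers, to make the avg_factor path concrete (values chosen for illustration only):

import torch

loss = torch.tensor([1., 2., 3., 4.])
weight = torch.tensor([1., 1., 0., 0.])
weight_reduce_loss(loss, weight, reduction='mean', avg_factor=2)   # tensor(1.5000)
weight_reduce_loss(loss, None, reduction='sum')                    # tensor(10.)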
diff --git a/annotator/uniformer/mmseg/models/necks/__init__.py b/annotator/uniformer/mmseg/models/necks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b9d3d5b3fe80247642d962edd6fb787537d01d6
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/necks/__init__.py
@@ -0,0 +1,4 @@
+from .fpn import FPN
+from .multilevel_neck import MultiLevelNeck
+
+__all__ = ['FPN', 'MultiLevelNeck']
diff --git a/annotator/uniformer/mmseg/models/necks/fpn.py b/annotator/uniformer/mmseg/models/necks/fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..a53b2a69500f8c2edb835abc3ff0ccc2173d1fb1
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/necks/fpn.py
@@ -0,0 +1,212 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule, xavier_init
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class FPN(nn.Module):
+    """Feature Pyramid Network.
+
+    This is an implementation of - Feature Pyramid Networks for Object
+    Detection (https://arxiv.org/abs/1612.03144)
+
+    Args:
+        in_channels (List[int]): Number of input channels per scale.
+        out_channels (int): Number of output channels (used at each scale)
+        num_outs (int): Number of output scales.
+        start_level (int): Index of the start input backbone level used to
+            build the feature pyramid. Default: 0.
+        end_level (int): Index of the end input backbone level (exclusive) to
+            build the feature pyramid. Default: -1, which means the last level.
+        add_extra_convs (bool | str): If bool, it decides whether to add conv
+            layers on top of the original feature maps. Default to False.
+            If True, its actual mode is specified by `extra_convs_on_inputs`.
+            If str, it specifies the source feature map of the extra convs.
+            Only the following options are allowed
+
+            - 'on_input': Last feat map of neck inputs (i.e. backbone feature).
+            - 'on_lateral':  Last feature map after lateral convs.
+            - 'on_output': The last output feature map after fpn convs.
+        extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs
+            on the original feature from the backbone. If True,
+            it is equivalent to `add_extra_convs='on_input'`. If False, it is
+            equivalent to `add_extra_convs='on_output'`. Default: False.
+        relu_before_extra_convs (bool): Whether to apply relu before the extra
+            conv. Default: False.
+        no_norm_on_lateral (bool): Whether to apply norm on lateral.
+            Default: False.
+        conv_cfg (dict): Config dict for convolution layer. Default: None.
+        norm_cfg (dict): Config dict for normalization layer. Default: None.
+        act_cfg (str): Config dict for activation layer in ConvModule.
+            Default: None.
+        upsample_cfg (dict): Config dict for interpolate layer.
+            Default: `dict(mode='nearest')`
+
+    Example:
+        >>> import torch
+        >>> in_channels = [2, 3, 5, 7]
+        >>> scales = [340, 170, 84, 43]
+        >>> inputs = [torch.rand(1, c, s, s)
+        ...           for c, s in zip(in_channels, scales)]
+        >>> self = FPN(in_channels, 11, len(in_channels)).eval()
+        >>> outputs = self.forward(inputs)
+        >>> for i in range(len(outputs)):
+        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
+        outputs[0].shape = torch.Size([1, 11, 340, 340])
+        outputs[1].shape = torch.Size([1, 11, 170, 170])
+        outputs[2].shape = torch.Size([1, 11, 84, 84])
+        outputs[3].shape = torch.Size([1, 11, 43, 43])
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 num_outs,
+                 start_level=0,
+                 end_level=-1,
+                 add_extra_convs=False,
+                 extra_convs_on_inputs=False,
+                 relu_before_extra_convs=False,
+                 no_norm_on_lateral=False,
+                 conv_cfg=None,
+                 norm_cfg=None,
+                 act_cfg=None,
+                 upsample_cfg=dict(mode='nearest')):
+        super(FPN, self).__init__()
+        assert isinstance(in_channels, list)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.num_ins = len(in_channels)
+        self.num_outs = num_outs
+        self.relu_before_extra_convs = relu_before_extra_convs
+        self.no_norm_on_lateral = no_norm_on_lateral
+        self.fp16_enabled = False
+        self.upsample_cfg = upsample_cfg.copy()
+
+        if end_level == -1:
+            self.backbone_end_level = self.num_ins
+            assert num_outs >= self.num_ins - start_level
+        else:
+            # if end_level < inputs, no extra level is allowed
+            self.backbone_end_level = end_level
+            assert end_level <= len(in_channels)
+            assert num_outs == end_level - start_level
+        self.start_level = start_level
+        self.end_level = end_level
+        self.add_extra_convs = add_extra_convs
+        assert isinstance(add_extra_convs, (str, bool))
+        if isinstance(add_extra_convs, str):
+            # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
+            assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
+        elif add_extra_convs:  # True
+            if extra_convs_on_inputs:
+                # For compatibility with previous release
+                # TODO: deprecate `extra_convs_on_inputs`
+                self.add_extra_convs = 'on_input'
+            else:
+                self.add_extra_convs = 'on_output'
+
+        self.lateral_convs = nn.ModuleList()
+        self.fpn_convs = nn.ModuleList()
+
+        for i in range(self.start_level, self.backbone_end_level):
+            l_conv = ConvModule(
+                in_channels[i],
+                out_channels,
+                1,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
+                act_cfg=act_cfg,
+                inplace=False)
+            fpn_conv = ConvModule(
+                out_channels,
+                out_channels,
+                3,
+                padding=1,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg,
+                inplace=False)
+
+            self.lateral_convs.append(l_conv)
+            self.fpn_convs.append(fpn_conv)
+
+        # add extra conv layers (e.g., RetinaNet)
+        extra_levels = num_outs - self.backbone_end_level + self.start_level
+        if self.add_extra_convs and extra_levels >= 1:
+            for i in range(extra_levels):
+                if i == 0 and self.add_extra_convs == 'on_input':
+                    in_channels = self.in_channels[self.backbone_end_level - 1]
+                else:
+                    in_channels = out_channels
+                extra_fpn_conv = ConvModule(
+                    in_channels,
+                    out_channels,
+                    3,
+                    stride=2,
+                    padding=1,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg,
+                    inplace=False)
+                self.fpn_convs.append(extra_fpn_conv)
+
+    # default init_weights: Xavier-uniform init for the plain Conv2d layers
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                xavier_init(m, distribution='uniform')
+
+    def forward(self, inputs):
+        assert len(inputs) == len(self.in_channels)
+
+        # build laterals
+        laterals = [
+            lateral_conv(inputs[i + self.start_level])
+            for i, lateral_conv in enumerate(self.lateral_convs)
+        ]
+
+        # build top-down path
+        used_backbone_levels = len(laterals)
+        for i in range(used_backbone_levels - 1, 0, -1):
+            # In some cases, fixing `scale_factor` (e.g. 2) is preferred, but
+            # it cannot co-exist with `size` in `F.interpolate`.
+            if 'scale_factor' in self.upsample_cfg:
+                laterals[i - 1] += F.interpolate(laterals[i],
+                                                 **self.upsample_cfg)
+            else:
+                prev_shape = laterals[i - 1].shape[2:]
+                laterals[i - 1] += F.interpolate(
+                    laterals[i], size=prev_shape, **self.upsample_cfg)
+
+        # build outputs
+        # part 1: from original levels
+        outs = [
+            self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
+        ]
+        # part 2: add extra levels
+        if self.num_outs > len(outs):
+            # use max pool to get more levels on top of outputs
+            # (e.g., Faster R-CNN, Mask R-CNN)
+            if not self.add_extra_convs:
+                for i in range(self.num_outs - used_backbone_levels):
+                    outs.append(F.max_pool2d(outs[-1], 1, stride=2))
+            # add conv layers on top of original feature maps (RetinaNet)
+            else:
+                if self.add_extra_convs == 'on_input':
+                    extra_source = inputs[self.backbone_end_level - 1]
+                elif self.add_extra_convs == 'on_lateral':
+                    extra_source = laterals[-1]
+                elif self.add_extra_convs == 'on_output':
+                    extra_source = outs[-1]
+                else:
+                    raise NotImplementedError
+                outs.append(self.fpn_convs[used_backbone_levels](extra_source))
+                for i in range(used_backbone_levels + 1, self.num_outs):
+                    if self.relu_before_extra_convs:
+                        outs.append(self.fpn_convs[i](F.relu(outs[-1])))
+                    else:
+                        outs.append(self.fpn_convs[i](outs[-1]))
+        return tuple(outs)
diff --git a/annotator/uniformer/mmseg/models/necks/multilevel_neck.py b/annotator/uniformer/mmseg/models/necks/multilevel_neck.py
new file mode 100644
index 0000000000000000000000000000000000000000..766144d8136326a1fab5906a153a0c0df69b6b60
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/necks/multilevel_neck.py
@@ -0,0 +1,70 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class MultiLevelNeck(nn.Module):
+    """MultiLevelNeck.
+
+    A neck structure that connects a ViT backbone to the decode heads.
+
+    Args:
+        in_channels (List[int]): Number of input channels per scale.
+        out_channels (int): Number of output channels (used at each scale).
+        scales (List[int]): Scale factors for each input feature map.
+        norm_cfg (dict): Config dict for normalization layer. Default: None.
+        act_cfg (dict): Config dict for activation layer in ConvModule.
+            Default: None.
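+
+    Example:
+        A minimal usage sketch; the channel counts, scales and feature-map
+        sizes below are illustrative assumptions, not values taken from this
+        repo's configs:
+
+        >>> import torch
+        >>> neck = MultiLevelNeck(in_channels=[768, 768, 768, 768],
+        ...                       out_channels=256,
+        ...                       scales=[4, 2, 1, 0.5])
+        >>> feats = [torch.randn(1, 768, 32, 32) for _ in range(4)]
+        >>> outs = neck(feats)
+        >>> [out.shape[-1] for out in outs]
+        [128, 64, 32, 16]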
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 scales=[0.5, 1, 2, 4],
+                 norm_cfg=None,
+                 act_cfg=None):
+        super(MultiLevelNeck, self).__init__()
+        assert isinstance(in_channels, list)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.scales = scales
+        self.num_outs = len(scales)
+        self.lateral_convs = nn.ModuleList()
+        self.convs = nn.ModuleList()
+        for in_channel in in_channels:
+            self.lateral_convs.append(
+                ConvModule(
+                    in_channel,
+                    out_channels,
+                    kernel_size=1,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg))
+        for _ in range(self.num_outs):
+            self.convs.append(
+                ConvModule(
+                    out_channels,
+                    out_channels,
+                    kernel_size=3,
+                    padding=1,
+                    stride=1,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg))
+
+    def forward(self, inputs):
+        assert len(inputs) == len(self.in_channels)
+        # print(inputs[0].shape)  # stray debug print, disabled
+        inputs = [
+            lateral_conv(inputs[i])
+            for i, lateral_conv in enumerate(self.lateral_convs)
+        ]
+        # when a single input is given, replicate it to match self.num_outs
+        if len(inputs) == 1:
+            inputs = [inputs[0] for _ in range(self.num_outs)]
+        outs = []
+        for i in range(self.num_outs):
+            x_resize = F.interpolate(
+                inputs[i], scale_factor=self.scales[i], mode='bilinear')
+            outs.append(self.convs[i](x_resize))
+        return tuple(outs)
diff --git a/annotator/uniformer/mmseg/models/segmentors/__init__.py b/annotator/uniformer/mmseg/models/segmentors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dca2f09405330743c476e190896bee39c45498ea
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/segmentors/__init__.py
@@ -0,0 +1,5 @@
+from .base import BaseSegmentor
+from .cascade_encoder_decoder import CascadeEncoderDecoder
+from .encoder_decoder import EncoderDecoder
+
+__all__ = ['BaseSegmentor', 'EncoderDecoder', 'CascadeEncoderDecoder']
diff --git a/annotator/uniformer/mmseg/models/segmentors/base.py b/annotator/uniformer/mmseg/models/segmentors/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..172fc63b736c4f13be1cd909433bc260760a1eaa
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/segmentors/base.py
@@ -0,0 +1,273 @@
+import logging
+import warnings
+from abc import ABCMeta, abstractmethod
+from collections import OrderedDict
+
+import annotator.uniformer.mmcv as mmcv
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from annotator.uniformer.mmcv.runner import auto_fp16
+
+
+class BaseSegmentor(nn.Module):
+    """Base class for segmentors."""
+
+    __metaclass__ = ABCMeta
+
+    def __init__(self):
+        super(BaseSegmentor, self).__init__()
+        self.fp16_enabled = False
+
+    @property
+    def with_neck(self):
+        """bool: whether the segmentor has neck"""
+        return hasattr(self, 'neck') and self.neck is not None
+
+    @property
+    def with_auxiliary_head(self):
+        """bool: whether the segmentor has auxiliary head"""
+        return hasattr(self,
+                       'auxiliary_head') and self.auxiliary_head is not None
+
+    @property
+    def with_decode_head(self):
+        """bool: whether the segmentor has decode head"""
+        return hasattr(self, 'decode_head') and self.decode_head is not None
+
+    @abstractmethod
+    def extract_feat(self, imgs):
+        """Placeholder for extract features from images."""
+        pass
+
+    @abstractmethod
+    def encode_decode(self, img, img_metas):
+        """Placeholder for encode images with backbone and decode into a
+        semantic segmentation map of the same size as input."""
+        pass
+
+    @abstractmethod
+    def forward_train(self, imgs, img_metas, **kwargs):
+        """Placeholder for Forward function for training."""
+        pass
+
+    @abstractmethod
+    def simple_test(self, img, img_meta, **kwargs):
+        """Placeholder for single image test."""
+        pass
+
+    @abstractmethod
+    def aug_test(self, imgs, img_metas, **kwargs):
+        """Placeholder for augmentation test."""
+        pass
+
+    def init_weights(self, pretrained=None):
+        """Initialize the weights in segmentor.
+
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+        """
+        if pretrained is not None:
+            logger = logging.getLogger()
+            logger.info(f'load model from: {pretrained}')
+
+    def forward_test(self, imgs, img_metas, **kwargs):
+        """
+        Args:
+            imgs (List[Tensor]): the outer list indicates test-time
+                augmentations and inner Tensor should have a shape NxCxHxW,
+                which contains all images in the batch.
+            img_metas (List[List[dict]]): the outer list indicates test-time
+                augs (multiscale, flip, etc.) and the inner list indicates
+                images in a batch.
+        """
+        for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
+            if not isinstance(var, list):
+                raise TypeError(f'{name} must be a list, but got '
+                                f'{type(var)}')
+
+        num_augs = len(imgs)
+        if num_augs != len(img_metas):
+            raise ValueError(f'num of augmentations ({len(imgs)}) != '
+                             f'num of image meta ({len(img_metas)})')
+        # all images in the same aug batch should have the same ori_shape and
+        # pad_shape
+        for img_meta in img_metas:
+            ori_shapes = [_['ori_shape'] for _ in img_meta]
+            assert all(shape == ori_shapes[0] for shape in ori_shapes)
+            img_shapes = [_['img_shape'] for _ in img_meta]
+            assert all(shape == img_shapes[0] for shape in img_shapes)
+            pad_shapes = [_['pad_shape'] for _ in img_meta]
+            assert all(shape == pad_shapes[0] for shape in pad_shapes)
+
+        if num_augs == 1:
+            return self.simple_test(imgs[0], img_metas[0], **kwargs)
+        else:
+            return self.aug_test(imgs, img_metas, **kwargs)
+
+    @auto_fp16(apply_to=('img', ))
+    def forward(self, img, img_metas, return_loss=True, **kwargs):
+        """Calls either :func:`forward_train` or :func:`forward_test` depending
+        on whether ``return_loss`` is ``True``.
+
+        Note this setting will change the expected inputs. When
+        ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
+        and List[dict]), and when ``return_loss=False``, img and img_meta
+        should be double nested (i.e. List[Tensor], List[List[dict]]), with
+        the outer list indicating test time augmentations.
+        """
+        if return_loss:
+            return self.forward_train(img, img_metas, **kwargs)
+        else:
+            return self.forward_test(img, img_metas, **kwargs)
+
+    def train_step(self, data_batch, optimizer, **kwargs):
+        """The iteration step during training.
+
+        This method defines an iteration step during training, except for the
+        back propagation and optimizer updating, which are done in an optimizer
+        hook. Note that in some complicated cases or models, the whole process
+        including back propagation and optimizer updating is also defined in
+        this method, such as GAN.
+
+        Args:
+            data_batch (dict): The output of dataloader.
+            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
+                runner is passed to ``train_step()``. This argument is unused
+                and reserved.
+
+        Returns:
+            dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
+                ``num_samples``.
+                ``loss`` is a tensor for back propagation, which can be a
+                weighted sum of multiple losses.
+                ``log_vars`` contains all the variables to be sent to the
+                logger.
+                ``num_samples`` indicates the batch size (when the model is
+                DDP, it means the batch size on each GPU), which is used for
+                averaging the logs.
+        """
+        losses = self(**data_batch)
+        loss, log_vars = self._parse_losses(losses)
+
+        outputs = dict(
+            loss=loss,
+            log_vars=log_vars,
+            num_samples=len(data_batch['img_metas']))
+
+        return outputs
+
+    def val_step(self, data_batch, **kwargs):
+        """The iteration step during validation.
+
+        This method shares the same signature as :func:`train_step`, but is
+        used during val epochs. Note that the evaluation after training epochs
+        is not implemented by this method, but by an evaluation hook.
+        """
+        output = self(**data_batch, **kwargs)
+        return output
+
+    @staticmethod
+    def _parse_losses(losses):
+        """Parse the raw outputs (losses) of the network.
+
+        Args:
+            losses (dict): Raw output of the network, which usually contains
+                losses and other necessary information.
+
+        Returns:
+            tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
+                which may be a weighted sum of all losses, log_vars contains
+                all the variables to be sent to the logger.
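+
+        Example:
+            An illustrative sketch; the loss names and values below are made
+            up for demonstration:
+
+            >>> import torch
+            >>> raw = dict(loss_ce=torch.tensor(0.5),
+            ...            loss_aux=[torch.tensor(0.25)])
+            >>> loss, log_vars = BaseSegmentor._parse_losses(raw)
+            >>> float(loss), log_vars['loss']
+            (0.75, 0.75)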
+        """
+        log_vars = OrderedDict()
+        for loss_name, loss_value in losses.items():
+            if isinstance(loss_value, torch.Tensor):
+                log_vars[loss_name] = loss_value.mean()
+            elif isinstance(loss_value, list):
+                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
+            else:
+                raise TypeError(
+                    f'{loss_name} is not a tensor or list of tensors')
+
+        loss = sum(_value for _key, _value in log_vars.items()
+                   if 'loss' in _key)
+
+        log_vars['loss'] = loss
+        for loss_name, loss_value in log_vars.items():
+            # reduce loss when distributed training
+            if dist.is_available() and dist.is_initialized():
+                loss_value = loss_value.data.clone()
+                dist.all_reduce(loss_value.div_(dist.get_world_size()))
+            log_vars[loss_name] = loss_value.item()
+
+        return loss, log_vars
+
+    def show_result(self,
+                    img,
+                    result,
+                    palette=None,
+                    win_name='',
+                    show=False,
+                    wait_time=0,
+                    out_file=None,
+                    opacity=0.5):
+        """Draw `result` over `img`.
+
+        Args:
+            img (str or Tensor): The image to be displayed.
+            result (Tensor): The semantic segmentation results to draw over
+                `img`.
+            palette (list[list[int]] | np.ndarray | None): The palette of
+                segmentation map. If None is given, random palette will be
+                generated. Default: None
+            win_name (str): The window name.
+            wait_time (int): Value of waitKey param.
+                Default: 0.
+            show (bool): Whether to show the image.
+                Default: False.
+            out_file (str or None): The filename to write the image.
+                Default: None.
+            opacity(float): Opacity of painted segmentation map.
+                Default 0.5.
+                Must be in (0, 1] range.
+        Returns:
+            img (np.ndarray): The drawn image; returned only when neither
+                `show` nor `out_file` is set.
+        """
+        img = mmcv.imread(img)
+        img = img.copy()
+        seg = result[0]
+        if palette is None:
+            if self.PALETTE is None:
+                palette = np.random.randint(
+                    0, 255, size=(len(self.CLASSES), 3))
+            else:
+                palette = self.PALETTE
+        palette = np.array(palette)
+        assert palette.shape[0] == len(self.CLASSES)
+        assert palette.shape[1] == 3
+        assert len(palette.shape) == 2
+        assert 0 < opacity <= 1.0
+        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
+        for label, color in enumerate(palette):
+            color_seg[seg == label, :] = color
+        # convert to BGR
+        color_seg = color_seg[..., ::-1]
+
+        img = img * (1 - opacity) + color_seg * opacity
+        img = img.astype(np.uint8)
+        # if out_file specified, do not show image in window
+        if out_file is not None:
+            show = False
+
+        if show:
+            mmcv.imshow(img, win_name, wait_time)
+        if out_file is not None:
+            mmcv.imwrite(img, out_file)
+
+        if not (show or out_file):
+            warnings.warn('show==False and out_file is not specified, only '
+                          'result image will be returned')
+            return img
diff --git a/annotator/uniformer/mmseg/models/segmentors/cascade_encoder_decoder.py b/annotator/uniformer/mmseg/models/segmentors/cascade_encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..873957d8d6468147c994493d92ff5c1b15bfb703
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/segmentors/cascade_encoder_decoder.py
@@ -0,0 +1,98 @@
+from torch import nn
+
+from annotator.uniformer.mmseg.core import add_prefix
+from annotator.uniformer.mmseg.ops import resize
+from .. import builder
+from ..builder import SEGMENTORS
+from .encoder_decoder import EncoderDecoder
+
+
+@SEGMENTORS.register_module()
+class CascadeEncoderDecoder(EncoderDecoder):
+    """Cascade Encoder Decoder segmentors.
+
+    CascadeEncoderDecoder is almost the same as EncoderDecoder, except that
+    its decode heads are cascaded: the output of the previous decode_head is
+    the input of the next decode_head.
+    """
+
+    def __init__(self,
+                 num_stages,
+                 backbone,
+                 decode_head,
+                 neck=None,
+                 auxiliary_head=None,
+                 train_cfg=None,
+                 test_cfg=None,
+                 pretrained=None):
+        self.num_stages = num_stages
+        super(CascadeEncoderDecoder, self).__init__(
+            backbone=backbone,
+            decode_head=decode_head,
+            neck=neck,
+            auxiliary_head=auxiliary_head,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+            pretrained=pretrained)
+
+    def _init_decode_head(self, decode_head):
+        """Initialize ``decode_head``"""
+        assert isinstance(decode_head, list)
+        assert len(decode_head) == self.num_stages
+        self.decode_head = nn.ModuleList()
+        for i in range(self.num_stages):
+            self.decode_head.append(builder.build_head(decode_head[i]))
+        self.align_corners = self.decode_head[-1].align_corners
+        self.num_classes = self.decode_head[-1].num_classes
+
+    def init_weights(self, pretrained=None):
+        """Initialize the weights in backbone and heads.
+
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+        """
+        self.backbone.init_weights(pretrained=pretrained)
+        for i in range(self.num_stages):
+            self.decode_head[i].init_weights()
+        if self.with_auxiliary_head:
+            if isinstance(self.auxiliary_head, nn.ModuleList):
+                for aux_head in self.auxiliary_head:
+                    aux_head.init_weights()
+            else:
+                self.auxiliary_head.init_weights()
+
+    def encode_decode(self, img, img_metas):
+        """Encode images with backbone and decode into a semantic segmentation
+        map of the same size as input."""
+        x = self.extract_feat(img)
+        out = self.decode_head[0].forward_test(x, img_metas, self.test_cfg)
+        for i in range(1, self.num_stages):
+            out = self.decode_head[i].forward_test(x, out, img_metas,
+                                                   self.test_cfg)
+        out = resize(
+            input=out,
+            size=img.shape[2:],
+            mode='bilinear',
+            align_corners=self.align_corners)
+        return out
+
+    def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg):
+        """Run forward function and calculate loss for decode head in
+        training."""
+        losses = dict()
+
+        loss_decode = self.decode_head[0].forward_train(
+            x, img_metas, gt_semantic_seg, self.train_cfg)
+
+        losses.update(add_prefix(loss_decode, 'decode_0'))
+
+        for i in range(1, self.num_stages):
+            # forward test again, maybe unnecessary for most methods.
+            prev_outputs = self.decode_head[i - 1].forward_test(
+                x, img_metas, self.test_cfg)
+            loss_decode = self.decode_head[i].forward_train(
+                x, prev_outputs, img_metas, gt_semantic_seg, self.train_cfg)
+            losses.update(add_prefix(loss_decode, f'decode_{i}'))
+
+        return losses
diff --git a/annotator/uniformer/mmseg/models/segmentors/encoder_decoder.py b/annotator/uniformer/mmseg/models/segmentors/encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..98392ac04c4c44a7f4e7b1c0808266875877dd1f
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/segmentors/encoder_decoder.py
@@ -0,0 +1,298 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from annotator.uniformer.mmseg.core import add_prefix
+from annotator.uniformer.mmseg.ops import resize
+from .. import builder
+from ..builder import SEGMENTORS
+from .base import BaseSegmentor
+
+
+@SEGMENTORS.register_module()
+class EncoderDecoder(BaseSegmentor):
+    """Encoder Decoder segmentors.
+
+    EncoderDecoder typically consists of backbone, decode_head, auxiliary_head.
+    Note that auxiliary_head is only used for deep supervision during training,
+    which could be dumped during inference.
+    """
+
+    def __init__(self,
+                 backbone,
+                 decode_head,
+                 neck=None,
+                 auxiliary_head=None,
+                 train_cfg=None,
+                 test_cfg=None,
+                 pretrained=None):
+        super(EncoderDecoder, self).__init__()
+        self.backbone = builder.build_backbone(backbone)
+        if neck is not None:
+            self.neck = builder.build_neck(neck)
+        self._init_decode_head(decode_head)
+        self._init_auxiliary_head(auxiliary_head)
+
+        self.train_cfg = train_cfg
+        self.test_cfg = test_cfg
+
+        self.init_weights(pretrained=pretrained)
+
+        assert self.with_decode_head
+
+    def _init_decode_head(self, decode_head):
+        """Initialize ``decode_head``"""
+        self.decode_head = builder.build_head(decode_head)
+        self.align_corners = self.decode_head.align_corners
+        self.num_classes = self.decode_head.num_classes
+
+    def _init_auxiliary_head(self, auxiliary_head):
+        """Initialize ``auxiliary_head``"""
+        if auxiliary_head is not None:
+            if isinstance(auxiliary_head, list):
+                self.auxiliary_head = nn.ModuleList()
+                for head_cfg in auxiliary_head:
+                    self.auxiliary_head.append(builder.build_head(head_cfg))
+            else:
+                self.auxiliary_head = builder.build_head(auxiliary_head)
+
+    def init_weights(self, pretrained=None):
+        """Initialize the weights in backbone and heads.
+
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+        """
+
+        super(EncoderDecoder, self).init_weights(pretrained)
+        self.backbone.init_weights(pretrained=pretrained)
+        self.decode_head.init_weights()
+        if self.with_auxiliary_head:
+            if isinstance(self.auxiliary_head, nn.ModuleList):
+                for aux_head in self.auxiliary_head:
+                    aux_head.init_weights()
+            else:
+                self.auxiliary_head.init_weights()
+
+    def extract_feat(self, img):
+        """Extract features from images."""
+        x = self.backbone(img)
+        if self.with_neck:
+            x = self.neck(x)
+        return x
+
+    def encode_decode(self, img, img_metas):
+        """Encode images with backbone and decode into a semantic segmentation
+        map of the same size as input."""
+        x = self.extract_feat(img)
+        out = self._decode_head_forward_test(x, img_metas)
+        out = resize(
+            input=out,
+            size=img.shape[2:],
+            mode='bilinear',
+            align_corners=self.align_corners)
+        return out
+
+    def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg):
+        """Run forward function and calculate loss for decode head in
+        training."""
+        losses = dict()
+        loss_decode = self.decode_head.forward_train(x, img_metas,
+                                                     gt_semantic_seg,
+                                                     self.train_cfg)
+
+        losses.update(add_prefix(loss_decode, 'decode'))
+        return losses
+
+    def _decode_head_forward_test(self, x, img_metas):
+        """Run forward function and calculate loss for decode head in
+        inference."""
+        seg_logits = self.decode_head.forward_test(x, img_metas, self.test_cfg)
+        return seg_logits
+
+    def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg):
+        """Run forward function and calculate loss for auxiliary head in
+        training."""
+        losses = dict()
+        if isinstance(self.auxiliary_head, nn.ModuleList):
+            for idx, aux_head in enumerate(self.auxiliary_head):
+                loss_aux = aux_head.forward_train(x, img_metas,
+                                                  gt_semantic_seg,
+                                                  self.train_cfg)
+                losses.update(add_prefix(loss_aux, f'aux_{idx}'))
+        else:
+            loss_aux = self.auxiliary_head.forward_train(
+                x, img_metas, gt_semantic_seg, self.train_cfg)
+            losses.update(add_prefix(loss_aux, 'aux'))
+
+        return losses
+
+    def forward_dummy(self, img):
+        """Dummy forward function."""
+        seg_logit = self.encode_decode(img, None)
+
+        return seg_logit
+
+    def forward_train(self, img, img_metas, gt_semantic_seg):
+        """Forward function for training.
+
+        Args:
+            img (Tensor): Input images.
+            img_metas (list[dict]): List of image info dict where each dict
+                has: 'img_shape', 'scale_factor', 'flip', and may also contain
+                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                For details on the values of these keys see
+                `mmseg/datasets/pipelines/formatting.py:Collect`.
+            gt_semantic_seg (Tensor): Semantic segmentation masks
+                used if the architecture supports semantic segmentation task.
+
+        Returns:
+            dict[str, Tensor]: a dictionary of loss components
+        """
+
+        x = self.extract_feat(img)
+
+        losses = dict()
+
+        loss_decode = self._decode_head_forward_train(x, img_metas,
+                                                      gt_semantic_seg)
+        losses.update(loss_decode)
+
+        if self.with_auxiliary_head:
+            loss_aux = self._auxiliary_head_forward_train(
+                x, img_metas, gt_semantic_seg)
+            losses.update(loss_aux)
+
+        return losses
+
+    # TODO refactor
+    def slide_inference(self, img, img_meta, rescale):
+        """Inference by sliding-window with overlap.
+
+        If h_crop > h_img or w_crop > w_img, the small patch will be used to
+        decode without padding.
+        """
+
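+        # Worked example with hypothetical numbers (not taken from any config
+        # in this repo): for crop_size=(512, 512), stride=(341, 341) and a
+        # 1024x1024 input, h_grids = w_grids = (1024 - 512 + 340) // 341 + 1
+        # = 3, so 3 x 3 = 9 overlapping crops are decoded and the accumulated
+        # logits in `preds` are averaged by `count_mat`.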
+        h_stride, w_stride = self.test_cfg.stride
+        h_crop, w_crop = self.test_cfg.crop_size
+        batch_size, _, h_img, w_img = img.size()
+        num_classes = self.num_classes
+        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
+        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
+        preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
+        count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
+        for h_idx in range(h_grids):
+            for w_idx in range(w_grids):
+                y1 = h_idx * h_stride
+                x1 = w_idx * w_stride
+                y2 = min(y1 + h_crop, h_img)
+                x2 = min(x1 + w_crop, w_img)
+                y1 = max(y2 - h_crop, 0)
+                x1 = max(x2 - w_crop, 0)
+                crop_img = img[:, :, y1:y2, x1:x2]
+                crop_seg_logit = self.encode_decode(crop_img, img_meta)
+                preds += F.pad(crop_seg_logit,
+                               (int(x1), int(preds.shape[3] - x2), int(y1),
+                                int(preds.shape[2] - y2)))
+
+                count_mat[:, :, y1:y2, x1:x2] += 1
+        assert (count_mat == 0).sum() == 0
+        if torch.onnx.is_in_onnx_export():
+            # cast count_mat to constant while exporting to ONNX
+            count_mat = torch.from_numpy(
+                count_mat.cpu().detach().numpy()).to(device=img.device)
+        preds = preds / count_mat
+        if rescale:
+            preds = resize(
+                preds,
+                size=img_meta[0]['ori_shape'][:2],
+                mode='bilinear',
+                align_corners=self.align_corners,
+                warning=False)
+        return preds
+
+    def whole_inference(self, img, img_meta, rescale):
+        """Inference with full image."""
+
+        seg_logit = self.encode_decode(img, img_meta)
+        if rescale:
+            # support dynamic shape for onnx
+            if torch.onnx.is_in_onnx_export():
+                size = img.shape[2:]
+            else:
+                size = img_meta[0]['ori_shape'][:2]
+            seg_logit = resize(
+                seg_logit,
+                size=size,
+                mode='bilinear',
+                align_corners=self.align_corners,
+                warning=False)
+
+        return seg_logit
+
+    def inference(self, img, img_meta, rescale):
+        """Inference with slide/whole style.
+
+        Args:
+            img (Tensor): The input image of shape (N, 3, H, W).
+            img_meta (list[dict]): List of image info dicts, where each dict
+                has: 'img_shape', 'scale_factor', 'flip', and may also contain
+                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+                For details on the values of these keys see
+                `mmseg/datasets/pipelines/formatting.py:Collect`.
+            rescale (bool): Whether rescale back to original shape.
+
+        Returns:
+            Tensor: The output segmentation map.
+        """
+
+        assert self.test_cfg.mode in ['slide', 'whole']
+        ori_shape = img_meta[0]['ori_shape']
+        assert all(_['ori_shape'] == ori_shape for _ in img_meta)
+        if self.test_cfg.mode == 'slide':
+            seg_logit = self.slide_inference(img, img_meta, rescale)
+        else:
+            seg_logit = self.whole_inference(img, img_meta, rescale)
+        output = F.softmax(seg_logit, dim=1)
+        flip = img_meta[0]['flip']
+        if flip:
+            flip_direction = img_meta[0]['flip_direction']
+            assert flip_direction in ['horizontal', 'vertical']
+            if flip_direction == 'horizontal':
+                output = output.flip(dims=(3, ))
+            elif flip_direction == 'vertical':
+                output = output.flip(dims=(2, ))
+
+        return output
+
+    def simple_test(self, img, img_meta, rescale=True):
+        """Simple test with single image."""
+        seg_logit = self.inference(img, img_meta, rescale)
+        seg_pred = seg_logit.argmax(dim=1)
+        if torch.onnx.is_in_onnx_export():
+            # our inference backend only supports 4D output
+            seg_pred = seg_pred.unsqueeze(0)
+            return seg_pred
+        seg_pred = seg_pred.cpu().numpy()
+        # unravel batch dim
+        seg_pred = list(seg_pred)
+        return seg_pred
+
+    def aug_test(self, imgs, img_metas, rescale=True):
+        """Test with augmentations.
+
+        Only rescale=True is supported.
+        """
+        # aug_test rescales all imgs back to ori_shape for now
+        assert rescale
+        # to save memory, we accumulate the augmented seg logits in place
+        seg_logit = self.inference(imgs[0], img_metas[0], rescale)
+        for i in range(1, len(imgs)):
+            cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale)
+            seg_logit += cur_seg_logit
+        seg_logit /= len(imgs)
+        seg_pred = seg_logit.argmax(dim=1)
+        seg_pred = seg_pred.cpu().numpy()
+        # unravel batch dim
+        seg_pred = list(seg_pred)
+        return seg_pred
diff --git a/annotator/uniformer/mmseg/models/utils/__init__.py b/annotator/uniformer/mmseg/models/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d3bdd349b9f2ae499a2fcb2ac1d2e3c77befebe
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/utils/__init__.py
@@ -0,0 +1,13 @@
+from .drop import DropPath
+from .inverted_residual import InvertedResidual, InvertedResidualV3
+from .make_divisible import make_divisible
+from .res_layer import ResLayer
+from .se_layer import SELayer
+from .self_attention_block import SelfAttentionBlock
+from .up_conv_block import UpConvBlock
+from .weight_init import trunc_normal_
+
+__all__ = [
+    'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
+    'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'DropPath', 'trunc_normal_'
+]
diff --git a/annotator/uniformer/mmseg/models/utils/drop.py b/annotator/uniformer/mmseg/models/utils/drop.py
new file mode 100644
index 0000000000000000000000000000000000000000..4520b0ff407d2a95a864086bdbca0065f222aa63
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/utils/drop.py
@@ -0,0 +1,31 @@
+"""Modified from https://github.com/rwightman/pytorch-image-
+models/blob/master/timm/models/layers/drop.py."""
+
+import torch
+from torch import nn
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample  (when applied in main path of
+    residual blocks).
+
+    Args:
+        drop_prob (float): Drop rate for paths of model. Dropout rate has
+            to be between 0 and 1. Default: 0.
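+
+    Example:
+        A small illustrative sketch (shapes are made up); at inference time
+        the module is the identity:
+
+        >>> import torch
+        >>> drop = DropPath(drop_prob=0.2)
+        >>> x = torch.randn(4, 3, 8, 8)
+        >>> _ = drop.eval()
+        >>> torch.equal(drop(x), x)
+        True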
+    """
+
+    def __init__(self, drop_prob=0.):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.keep_prob = 1 - drop_prob
+
+    def forward(self, x):
+        if self.drop_prob == 0. or not self.training:
+            return x
+        shape = (x.shape[0], ) + (1, ) * (
+            x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+        random_tensor = self.keep_prob + torch.rand(
+            shape, dtype=x.dtype, device=x.device)
+        random_tensor.floor_()  # binarize
+        output = x.div(self.keep_prob) * random_tensor
+        return output
diff --git a/annotator/uniformer/mmseg/models/utils/inverted_residual.py b/annotator/uniformer/mmseg/models/utils/inverted_residual.py
new file mode 100644
index 0000000000000000000000000000000000000000..53b8fcd41f71d814738f1ac3f5acd3c3d701bf96
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/utils/inverted_residual.py
@@ -0,0 +1,208 @@
+from annotator.uniformer.mmcv.cnn import ConvModule
+from torch import nn
+from torch.utils import checkpoint as cp
+
+from .se_layer import SELayer
+
+
+class InvertedResidual(nn.Module):
+    """InvertedResidual block for MobileNetV2.
+
+    Args:
+        in_channels (int): The input channels of the InvertedResidual block.
+        out_channels (int): The output channels of the InvertedResidual block.
+        stride (int): Stride of the middle (first) 3x3 convolution.
+        expand_ratio (int): Adjusts number of channels of the hidden layer
+            in InvertedResidual by this amount.
+        dilation (int): Dilation rate of depthwise conv. Default: 1
+        conv_cfg (dict): Config dict for convolution layer.
+            Default: None, which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='ReLU6').
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+
+    Returns:
+        Tensor: The output tensor.
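+
+    Example:
+        An illustrative sketch; the channel counts and input shape are
+        assumptions chosen so that the residual connection is used:
+
+        >>> import torch
+        >>> block = InvertedResidual(32, 32, stride=1, expand_ratio=6)
+        >>> x = torch.randn(1, 32, 56, 56)
+        >>> block(x).shape
+        torch.Size([1, 32, 56, 56])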
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 stride,
+                 expand_ratio,
+                 dilation=1,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU6'),
+                 with_cp=False):
+        super(InvertedResidual, self).__init__()
+        self.stride = stride
+        assert stride in [1, 2], f'stride must be in [1, 2], ' \
+            f'but received {stride}.'
+        self.with_cp = with_cp
+        self.use_res_connect = self.stride == 1 and in_channels == out_channels
+        hidden_dim = int(round(in_channels * expand_ratio))
+
+        layers = []
+        if expand_ratio != 1:
+            layers.append(
+                ConvModule(
+                    in_channels=in_channels,
+                    out_channels=hidden_dim,
+                    kernel_size=1,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg))
+        layers.extend([
+            ConvModule(
+                in_channels=hidden_dim,
+                out_channels=hidden_dim,
+                kernel_size=3,
+                stride=stride,
+                padding=dilation,
+                dilation=dilation,
+                groups=hidden_dim,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg),
+            ConvModule(
+                in_channels=hidden_dim,
+                out_channels=out_channels,
+                kernel_size=1,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=None)
+        ])
+        self.conv = nn.Sequential(*layers)
+
+    def forward(self, x):
+
+        def _inner_forward(x):
+            if self.use_res_connect:
+                return x + self.conv(x)
+            else:
+                return self.conv(x)
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(_inner_forward, x)
+        else:
+            out = _inner_forward(x)
+
+        return out
+
+
+class InvertedResidualV3(nn.Module):
+    """Inverted Residual Block for MobileNetV3.
+
+    Args:
+        in_channels (int): The input channels of this Module.
+        out_channels (int): The output channels of this Module.
+        mid_channels (int): The input channels of the depthwise convolution.
+        kernel_size (int): The kernel size of the depthwise convolution.
+            Default: 3.
+        stride (int): The stride of the depthwise convolution. Default: 1.
+        se_cfg (dict): Config dict for se layer. Default: None, which means no
+            se layer.
+        with_expand_conv (bool): Use expand conv or not. If set False,
+            mid_channels must be the same as in_channels. Default: True.
+        conv_cfg (dict): Config dict for convolution layer. Default: None,
+            which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='ReLU').
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+
+    Returns:
+        Tensor: The output tensor.
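+
+    Example:
+        An illustrative sketch; all numbers are assumptions chosen so that
+        the residual shortcut is used:
+
+        >>> import torch
+        >>> block = InvertedResidualV3(
+        ...     in_channels=16, out_channels=16, mid_channels=64)
+        >>> x = torch.randn(1, 16, 28, 28)
+        >>> block(x).shape
+        torch.Size([1, 16, 28, 28])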
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 mid_channels,
+                 kernel_size=3,
+                 stride=1,
+                 se_cfg=None,
+                 with_expand_conv=True,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 with_cp=False):
+        super(InvertedResidualV3, self).__init__()
+        self.with_res_shortcut = (stride == 1 and in_channels == out_channels)
+        assert stride in [1, 2]
+        self.with_cp = with_cp
+        self.with_se = se_cfg is not None
+        self.with_expand_conv = with_expand_conv
+
+        if self.with_se:
+            assert isinstance(se_cfg, dict)
+        if not self.with_expand_conv:
+            assert mid_channels == in_channels
+
+        if self.with_expand_conv:
+            self.expand_conv = ConvModule(
+                in_channels=in_channels,
+                out_channels=mid_channels,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg)
+        self.depthwise_conv = ConvModule(
+            in_channels=mid_channels,
+            out_channels=mid_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=kernel_size // 2,
+            groups=mid_channels,
+            conv_cfg=dict(
+                type='Conv2dAdaptivePadding') if stride == 2 else conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+
+        if self.with_se:
+            self.se = SELayer(**se_cfg)
+
+        self.linear_conv = ConvModule(
+            in_channels=mid_channels,
+            out_channels=out_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=None)
+
+    def forward(self, x):
+
+        def _inner_forward(x):
+            out = x
+
+            if self.with_expand_conv:
+                out = self.expand_conv(out)
+
+            out = self.depthwise_conv(out)
+
+            if self.with_se:
+                out = self.se(out)
+
+            out = self.linear_conv(out)
+
+            if self.with_res_shortcut:
+                return x + out
+            else:
+                return out
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(_inner_forward, x)
+        else:
+            out = _inner_forward(x)
+
+        return out
diff --git a/annotator/uniformer/mmseg/models/utils/make_divisible.py b/annotator/uniformer/mmseg/models/utils/make_divisible.py
new file mode 100644
index 0000000000000000000000000000000000000000..75ad756052529f52fe83bb95dd1f0ecfc9a13078
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/utils/make_divisible.py
@@ -0,0 +1,27 @@
+def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
+    """Make divisible function.
+
+    This function rounds the channel number to the nearest value that can be
+    divisible by the divisor. It is taken from the original tf repo. It ensures
+    that all layers have a channel number that is divisible by divisor. It can
+    be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py  # noqa
+
+    Args:
+        value (int): The original channel number.
+        divisor (int): The divisor to fully divide the channel number.
+        min_value (int): The minimum value of the output channel.
+            Default: None, which means the minimum value equals the divisor.
+        min_ratio (float): The minimum ratio of the rounded channel number to
+            the original channel number. Default: 0.9.
+
+    Returns:
+        int: The modified output channel number.
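+
+    Example:
+        Illustrative values, not taken from any model config:
+
+        >>> make_divisible(30, 8)   # rounded to the nearest multiple of 8
+        32
+        >>> make_divisible(19, 8)   # 16 < 0.9 * 19, so bump up by one divisor
+        24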
+    """
+
+    if min_value is None:
+        min_value = divisor
+    new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
+    # Make sure that rounding down does not reduce the value by more
+    # than (1 - min_ratio).
+    if new_value < min_ratio * value:
+        new_value += divisor
+    return new_value
diff --git a/annotator/uniformer/mmseg/models/utils/res_layer.py b/annotator/uniformer/mmseg/models/utils/res_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2c07b47007e92e4c3945b989e79f9d50306f5fe
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/utils/res_layer.py
@@ -0,0 +1,94 @@
+from annotator.uniformer.mmcv.cnn import build_conv_layer, build_norm_layer
+from torch import nn as nn
+
+
+class ResLayer(nn.Sequential):
+    """ResLayer to build ResNet style backbone.
+
+    Args:
+        block (nn.Module): block used to build ResLayer.
+        inplanes (int): inplanes of block.
+        planes (int): planes of block.
+        num_blocks (int): number of blocks.
+        stride (int): stride of the first block. Default: 1
+        avg_down (bool): Use AvgPool instead of stride conv when
+            downsampling in the bottleneck. Default: False
+        conv_cfg (dict): dictionary to construct and config conv layer.
+            Default: None
+        norm_cfg (dict): dictionary to construct and config norm layer.
+            Default: dict(type='BN')
+        multi_grid (tuple[int] | None): Multi grid dilation rates of the last
+            stage. Default: None
+        contract_dilation (bool): Whether to contract the first dilation of
+            each layer. Default: False
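+
+    Example:
+        A minimal sketch, assuming the vendored ResNet backbone exposes
+        ``Bottleneck`` (the import path below is an assumption):
+
+        >>> import torch
+        >>> from annotator.uniformer.mmseg.models.backbones.resnet import (
+        ...     Bottleneck)
+        >>> layer = ResLayer(Bottleneck, inplanes=64, planes=64, num_blocks=2)
+        >>> x = torch.randn(1, 64, 56, 56)
+        >>> layer(x).shape   # Bottleneck.expansion == 4
+        torch.Size([1, 256, 56, 56])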
+    """
+
+    def __init__(self,
+                 block,
+                 inplanes,
+                 planes,
+                 num_blocks,
+                 stride=1,
+                 dilation=1,
+                 avg_down=False,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 multi_grid=None,
+                 contract_dilation=False,
+                 **kwargs):
+        self.block = block
+
+        downsample = None
+        if stride != 1 or inplanes != planes * block.expansion:
+            downsample = []
+            conv_stride = stride
+            if avg_down:
+                conv_stride = 1
+                downsample.append(
+                    nn.AvgPool2d(
+                        kernel_size=stride,
+                        stride=stride,
+                        ceil_mode=True,
+                        count_include_pad=False))
+            downsample.extend([
+                build_conv_layer(
+                    conv_cfg,
+                    inplanes,
+                    planes * block.expansion,
+                    kernel_size=1,
+                    stride=conv_stride,
+                    bias=False),
+                build_norm_layer(norm_cfg, planes * block.expansion)[1]
+            ])
+            downsample = nn.Sequential(*downsample)
+
+        layers = []
+        if multi_grid is None:
+            if dilation > 1 and contract_dilation:
+                first_dilation = dilation // 2
+            else:
+                first_dilation = dilation
+        else:
+            first_dilation = multi_grid[0]
+        layers.append(
+            block(
+                inplanes=inplanes,
+                planes=planes,
+                stride=stride,
+                dilation=first_dilation,
+                downsample=downsample,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                **kwargs))
+        inplanes = planes * block.expansion
+        for i in range(1, num_blocks):
+            layers.append(
+                block(
+                    inplanes=inplanes,
+                    planes=planes,
+                    stride=1,
+                    dilation=dilation if multi_grid is None else multi_grid[i],
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    **kwargs))
+        super(ResLayer, self).__init__(*layers)
diff --git a/annotator/uniformer/mmseg/models/utils/se_layer.py b/annotator/uniformer/mmseg/models/utils/se_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..083bd7d1ccee909c900c7aed2cc928bf14727f3e
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/utils/se_layer.py
@@ -0,0 +1,57 @@
+import annotator.uniformer.mmcv as mmcv
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule
+
+from .make_divisible import make_divisible
+
+
+class SELayer(nn.Module):
+    """Squeeze-and-Excitation Module.
+
+    Args:
+        channels (int): The input (and output) channels of the SE layer.
+        ratio (int): Squeeze ratio in SELayer, the intermediate channel number
+            will be ``make_divisible(channels // ratio, 8)``. Default: 16.
+        conv_cfg (None or dict): Config dict for convolution layer.
+            Default: None, which means using conv2d.
+        act_cfg (dict or Sequence[dict]): Config dict for activation layer.
+            If act_cfg is a dict, two activation layers will be configured
+            by this dict. If act_cfg is a sequence of dicts, the first
+            activation layer will be configured by the first dict and the
+            second activation layer will be configured by the second dict.
+            Default: (dict(type='ReLU'), dict(type='HSigmoid', bias=3.0,
+            divisor=6.0)).
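+
+    Example:
+        An illustrative sketch (channel count and input shape are
+        assumptions); the layer re-weights channels and keeps the shape:
+
+        >>> import torch
+        >>> se = SELayer(channels=64, ratio=16)
+        >>> x = torch.randn(2, 64, 32, 32)
+        >>> se(x).shape
+        torch.Size([2, 64, 32, 32])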
+    """
+
+    def __init__(self,
+                 channels,
+                 ratio=16,
+                 conv_cfg=None,
+                 act_cfg=(dict(type='ReLU'),
+                          dict(type='HSigmoid', bias=3.0, divisor=6.0))):
+        super(SELayer, self).__init__()
+        if isinstance(act_cfg, dict):
+            act_cfg = (act_cfg, act_cfg)
+        assert len(act_cfg) == 2
+        assert mmcv.is_tuple_of(act_cfg, dict)
+        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+        self.conv1 = ConvModule(
+            in_channels=channels,
+            out_channels=make_divisible(channels // ratio, 8),
+            kernel_size=1,
+            stride=1,
+            conv_cfg=conv_cfg,
+            act_cfg=act_cfg[0])
+        self.conv2 = ConvModule(
+            in_channels=make_divisible(channels // ratio, 8),
+            out_channels=channels,
+            kernel_size=1,
+            stride=1,
+            conv_cfg=conv_cfg,
+            act_cfg=act_cfg[1])
+
+    def forward(self, x):
+        out = self.global_avgpool(x)
+        out = self.conv1(out)
+        out = self.conv2(out)
+        return x * out
diff --git a/annotator/uniformer/mmseg/models/utils/self_attention_block.py b/annotator/uniformer/mmseg/models/utils/self_attention_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..440c7b73ee4706fde555595926d63a18d7574acc
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/utils/self_attention_block.py
@@ -0,0 +1,159 @@
+import torch
+from annotator.uniformer.mmcv.cnn import ConvModule, constant_init
+from torch import nn as nn
+from torch.nn import functional as F
+
+
+class SelfAttentionBlock(nn.Module):
+    """General self-attention block/non-local block.
+
+    Please refer to https://arxiv.org/abs/1706.03762 for details about key,
+    query and value.
+
+    Args:
+        key_in_channels (int): Input channels of key feature.
+        query_in_channels (int): Input channels of query feature.
+        channels (int): Output channels of key/query transform.
+        out_channels (int): Output channels.
+        share_key_query (bool): Whether share projection weight between key
+            and query projection.
+        query_downsample (nn.Module): Query downsample module.
+        key_downsample (nn.Module): Key downsample module.
+        key_query_num_convs (int): Number of convs for key/query projection.
+        value_out_num_convs (int): Number of convs for value/out projection.
+        key_query_norm (bool): Whether to use ConvModule (with norm and
+            activation) for key/query projection.
+        value_out_norm (bool): Whether to use ConvModule (with norm and
+            activation) for value/out projection.
+        matmul_norm (bool): Whether to normalize the attention map by
+            sqrt of channels.
+        with_out (bool): Whether to use the out projection.
+        conv_cfg (dict|None): Config of conv layers.
+        norm_cfg (dict|None): Config of norm layers.
+        act_cfg (dict|None): Config of activation layers.
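+
+    Example:
+        A minimal usage sketch; every channel count, flag and shape below is
+        an illustrative assumption:
+
+        >>> import torch
+        >>> block = SelfAttentionBlock(
+        ...     key_in_channels=64, query_in_channels=64, channels=32,
+        ...     out_channels=64, share_key_query=False,
+        ...     query_downsample=None, key_downsample=None,
+        ...     key_query_num_convs=1, value_out_num_convs=1,
+        ...     key_query_norm=False, value_out_norm=False,
+        ...     matmul_norm=True, with_out=True, conv_cfg=None,
+        ...     norm_cfg=None, act_cfg=None)
+        >>> query = torch.randn(2, 64, 16, 16)
+        >>> key = torch.randn(2, 64, 16, 16)
+        >>> block(query, key).shape
+        torch.Size([2, 64, 16, 16])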
+    """
+
+    def __init__(self, key_in_channels, query_in_channels, channels,
+                 out_channels, share_key_query, query_downsample,
+                 key_downsample, key_query_num_convs, value_out_num_convs,
+                 key_query_norm, value_out_norm, matmul_norm, with_out,
+                 conv_cfg, norm_cfg, act_cfg):
+        super(SelfAttentionBlock, self).__init__()
+        if share_key_query:
+            assert key_in_channels == query_in_channels
+        self.key_in_channels = key_in_channels
+        self.query_in_channels = query_in_channels
+        self.out_channels = out_channels
+        self.channels = channels
+        self.share_key_query = share_key_query
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        self.key_project = self.build_project(
+            key_in_channels,
+            channels,
+            num_convs=key_query_num_convs,
+            use_conv_module=key_query_norm,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+        if share_key_query:
+            self.query_project = self.key_project
+        else:
+            self.query_project = self.build_project(
+                query_in_channels,
+                channels,
+                num_convs=key_query_num_convs,
+                use_conv_module=key_query_norm,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg)
+        self.value_project = self.build_project(
+            key_in_channels,
+            channels if with_out else out_channels,
+            num_convs=value_out_num_convs,
+            use_conv_module=value_out_norm,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+        if with_out:
+            self.out_project = self.build_project(
+                channels,
+                out_channels,
+                num_convs=value_out_num_convs,
+                use_conv_module=value_out_norm,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg)
+        else:
+            self.out_project = None
+
+        self.query_downsample = query_downsample
+        self.key_downsample = key_downsample
+        self.matmul_norm = matmul_norm
+
+        self.init_weights()
+
+    def init_weights(self):
+        """Initialize weight of later layer."""
+        if self.out_project is not None:
+            if not isinstance(self.out_project, ConvModule):
+                constant_init(self.out_project, 0)
+
+    def build_project(self, in_channels, channels, num_convs, use_conv_module,
+                      conv_cfg, norm_cfg, act_cfg):
+        """Build projection layer for key/query/value/out."""
+        if use_conv_module:
+            convs = [
+                ConvModule(
+                    in_channels,
+                    channels,
+                    1,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg)
+            ]
+            for _ in range(num_convs - 1):
+                convs.append(
+                    ConvModule(
+                        channels,
+                        channels,
+                        1,
+                        conv_cfg=conv_cfg,
+                        norm_cfg=norm_cfg,
+                        act_cfg=act_cfg))
+        else:
+            convs = [nn.Conv2d(in_channels, channels, 1)]
+            for _ in range(num_convs - 1):
+                convs.append(nn.Conv2d(channels, channels, 1))
+        if len(convs) > 1:
+            convs = nn.Sequential(*convs)
+        else:
+            convs = convs[0]
+        return convs
+
+    def forward(self, query_feats, key_feats):
+        """Forward function."""
+        batch_size = query_feats.size(0)
+        query = self.query_project(query_feats)
+        if self.query_downsample is not None:
+            query = self.query_downsample(query)
+        query = query.reshape(*query.shape[:2], -1)
+        query = query.permute(0, 2, 1).contiguous()
+
+        key = self.key_project(key_feats)
+        value = self.value_project(key_feats)
+        if self.key_downsample is not None:
+            key = self.key_downsample(key)
+            value = self.key_downsample(value)
+        key = key.reshape(*key.shape[:2], -1)
+        value = value.reshape(*value.shape[:2], -1)
+        value = value.permute(0, 2, 1).contiguous()
+
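+        # query: [N, Hq*Wq, C], key: [N, C, Hk*Wk] -> sim_map: [N, Hq*Wq, Hk*Wk]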
+        sim_map = torch.matmul(query, key)
+        if self.matmul_norm:
+            sim_map = (self.channels**-.5) * sim_map
+        sim_map = F.softmax(sim_map, dim=-1)
+
+        context = torch.matmul(sim_map, value)
+        context = context.permute(0, 2, 1).contiguous()
+        context = context.reshape(batch_size, -1, *query_feats.shape[2:])
+        if self.out_project is not None:
+            context = self.out_project(context)
+        return context
diff --git a/annotator/uniformer/mmseg/models/utils/up_conv_block.py b/annotator/uniformer/mmseg/models/utils/up_conv_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..378469da76cb7bff6a639e7877b3c275d50490fb
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/utils/up_conv_block.py
@@ -0,0 +1,101 @@
+import torch
+import torch.nn as nn
+from annotator.uniformer.mmcv.cnn import ConvModule, build_upsample_layer
+
+
+class UpConvBlock(nn.Module):
+    """Upsample convolution block in decoder for UNet.
+
+    This upsample convolution block consists of one upsample module
+    followed by one convolution block. The upsample module expands the
+    high-level low-resolution feature map and the convolution block fuses
+    the upsampled high-level low-resolution feature map and the low-level
+    high-resolution feature map from encoder.
+
+    Args:
+        conv_block (nn.Sequential): Sequential of convolutional layers.
+        in_channels (int): Number of input channels of the high-level
+            low-resolution feature map from the decoder.
+        skip_channels (int): Number of input channels of the low-level
+            high-resolution feature map from the encoder.
+        out_channels (int): Number of output channels.
+        num_convs (int): Number of convolutional layers in the conv_block.
+            Default: 2.
+        stride (int): Stride of convolutional layer in conv_block. Default: 1.
+        dilation (int): Dilation rate of convolutional layer in conv_block.
+            Default: 1.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed. Default: False.
+        conv_cfg (dict | None): Config dict for convolution layer.
+            Default: None.
+        norm_cfg (dict | None): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict | None): Config dict for activation layer in ConvModule.
+            Default: dict(type='ReLU').
+        upsample_cfg (dict): The upsample config of the upsample module in
+            decoder. Default: dict(type='InterpConv'). If the size of the
+            high-level feature map already matches that of the skip feature map
+            (the low-level feature map from the encoder), no upsampling is
+            needed and upsample_cfg should be None.
+        dcn (bool): Use deformable convolution in convolutional layer or not.
+            Default: None.
+        plugins (dict): plugins for convolutional layers. Default: None.
+    """
+
+    def __init__(self,
+                 conv_block,
+                 in_channels,
+                 skip_channels,
+                 out_channels,
+                 num_convs=2,
+                 stride=1,
+                 dilation=1,
+                 with_cp=False,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 upsample_cfg=dict(type='InterpConv'),
+                 dcn=None,
+                 plugins=None):
+        super(UpConvBlock, self).__init__()
+        assert dcn is None, 'Not implemented yet.'
+        assert plugins is None, 'Not implemented yet.'
+
+        self.conv_block = conv_block(
+            in_channels=2 * skip_channels,
+            out_channels=out_channels,
+            num_convs=num_convs,
+            stride=stride,
+            dilation=dilation,
+            with_cp=with_cp,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg,
+            dcn=None,
+            plugins=None)
+        if upsample_cfg is not None:
+            self.upsample = build_upsample_layer(
+                cfg=upsample_cfg,
+                in_channels=in_channels,
+                out_channels=skip_channels,
+                with_cp=with_cp,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg)
+        else:
+            self.upsample = ConvModule(
+                in_channels,
+                skip_channels,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg)
+
+    def forward(self, skip, x):
+        """Forward function."""
+
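+        # Upsample the low-resolution decoder feature to the skip map's size,
+        # concatenate along channels, then fuse with the conv block.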
+        x = self.upsample(x)
+        out = torch.cat([skip, x], dim=1)
+        out = self.conv_block(out)
+
+        return out
diff --git a/annotator/uniformer/mmseg/models/utils/weight_init.py b/annotator/uniformer/mmseg/models/utils/weight_init.py
new file mode 100644
index 0000000000000000000000000000000000000000..38141ba3d61f64ddfc0a31574b4648cbad96d7dd
--- /dev/null
+++ b/annotator/uniformer/mmseg/models/utils/weight_init.py
@@ -0,0 +1,62 @@
+"""Modified from https://github.com/rwightman/pytorch-image-
+models/blob/master/timm/models/layers/drop.py."""
+
+import math
+import warnings
+
+import torch
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+    """Reference: https://people.sc.fsu.edu/~jburkardt/presentations
+    /truncated_normal.pdf"""
+
+    def norm_cdf(x):
+        # Computes standard normal cumulative distribution function
+        return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+    if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn(
+            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+            'The distribution of values may be incorrect.',
+            stacklevel=2)
+
+    with torch.no_grad():
+        # Values are generated by using a truncated uniform distribution and
+        # then using the inverse CDF for the normal distribution.
+        # Get upper and lower cdf values
+        lower_bound = norm_cdf((a - mean) / std)
+        upper_bound = norm_cdf((b - mean) / std)
+
+        # Uniformly fill tensor with values from [lower_bound, upper_bound],
+        # then translate to [2 * lower_bound - 1, 2 * upper_bound - 1].
+        tensor.uniform_(2 * lower_bound - 1, 2 * upper_bound - 1)
+
+        # Use inverse cdf transform for normal distribution to get truncated
+        # standard normal
+        tensor.erfinv_()
+
+        # Transform to proper mean, std
+        tensor.mul_(std * math.sqrt(2.))
+        tensor.add_(mean)
+
+        # Clamp to ensure it's in the proper range
+        tensor.clamp_(min=a, max=b)
+        return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+    r"""Fills the input Tensor with values drawn from a truncated
+    normal distribution. The values are effectively drawn from the
+    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+    with values outside :math:`[a, b]` redrawn until they are within
+    the bounds. The method used for generating the random values works
+    best when :math:`a \leq \text{mean} \leq b`.
+    Args:
+        tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`
+        mean (float): the mean of the normal distribution
+        std (float): the standard deviation of the normal distribution
+        a (float): the minimum cutoff value
+        b (float): the maximum cutoff value
+    """
+    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
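+
+
+# Example (illustrative):
+#   >>> w = torch.empty(3, 5)
+#   >>> trunc_normal_(w, std=0.02)  # in-place init, values clipped to [-2, 2]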
diff --git a/annotator/uniformer/mmseg/ops/__init__.py b/annotator/uniformer/mmseg/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bec51c75b9363a9a19e9fb5c35f4e7dbd6f7751c
--- /dev/null
+++ b/annotator/uniformer/mmseg/ops/__init__.py
@@ -0,0 +1,4 @@
+from .encoding import Encoding
+from .wrappers import Upsample, resize
+
+__all__ = ['Upsample', 'resize', 'Encoding']
diff --git a/annotator/uniformer/mmseg/ops/encoding.py b/annotator/uniformer/mmseg/ops/encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eb3629a6426550b8e4c537ee1ff4341893e489e
--- /dev/null
+++ b/annotator/uniformer/mmseg/ops/encoding.py
@@ -0,0 +1,74 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class Encoding(nn.Module):
+    """Encoding Layer: a learnable residual encoder.
+
+    Input is of shape  (batch_size, channels, height, width).
+    Output is of shape (batch_size, num_codes, channels).
+
+    Args:
+        channels: dimension of the features or feature channels
+        num_codes: number of code words
+    """
+
+    def __init__(self, channels, num_codes):
+        super(Encoding, self).__init__()
+        # init codewords and smoothing factor
+        self.channels, self.num_codes = channels, num_codes
+        std = 1. / ((num_codes * channels)**0.5)
+        # [num_codes, channels]
+        self.codewords = nn.Parameter(
+            torch.empty(num_codes, channels,
+                        dtype=torch.float).uniform_(-std, std),
+            requires_grad=True)
+        # [num_codes]
+        self.scale = nn.Parameter(
+            torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0),
+            requires_grad=True)
+
+    @staticmethod
+    def scaled_l2(x, codewords, scale):
+        num_codes, channels = codewords.size()
+        batch_size = x.size(0)
+        reshaped_scale = scale.view((1, 1, num_codes))
+        expanded_x = x.unsqueeze(2).expand(
+            (batch_size, x.size(1), num_codes, channels))
+        reshaped_codewords = codewords.view((1, 1, num_codes, channels))
+
+        scaled_l2_norm = reshaped_scale * (
+            expanded_x - reshaped_codewords).pow(2).sum(dim=3)
+        return scaled_l2_norm
+
+    @staticmethod
+    def aggregate(assignment_weights, x, codewords):
+        num_codes, channels = codewords.size()
+        reshaped_codewords = codewords.view((1, 1, num_codes, channels))
+        batch_size = x.size(0)
+
+        expanded_x = x.unsqueeze(2).expand(
+            (batch_size, x.size(1), num_codes, channels))
+        encoded_feat = (assignment_weights.unsqueeze(3) *
+                        (expanded_x - reshaped_codewords)).sum(dim=1)
+        return encoded_feat
+
+    def forward(self, x):
+        assert x.dim() == 4 and x.size(1) == self.channels
+        # [batch_size, channels, height, width]
+        batch_size = x.size(0)
+        # [batch_size, height x width, channels]
+        x = x.view(batch_size, self.channels, -1).transpose(1, 2).contiguous()
+        # assignment_weights: [batch_size, height x width, num_codes]
+        assignment_weights = F.softmax(
+            self.scaled_l2(x, self.codewords, self.scale), dim=2)
+        # aggregate
+        encoded_feat = self.aggregate(assignment_weights, x, self.codewords)
+        return encoded_feat
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \
+                    f'x{self.channels})'
+        return repr_str
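+
+
+# Example (illustrative): encode a feature map into residual code statistics
+#   >>> enc = Encoding(channels=256, num_codes=32)
+#   >>> feats = torch.randn(2, 256, 64, 64)
+#   >>> enc(feats).shape
+#   torch.Size([2, 32, 256])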
diff --git a/annotator/uniformer/mmseg/ops/wrappers.py b/annotator/uniformer/mmseg/ops/wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ed9a0cb8d7c0e0ec2748dd89c652756653cac78
--- /dev/null
+++ b/annotator/uniformer/mmseg/ops/wrappers.py
@@ -0,0 +1,50 @@
+import warnings
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def resize(input,
+           size=None,
+           scale_factor=None,
+           mode='nearest',
+           align_corners=None,
+           warning=True):
+    if warning:
+        if size is not None and align_corners:
+            input_h, input_w = tuple(int(x) for x in input.shape[2:])
+            output_h, output_w = tuple(int(x) for x in size)
+            if output_h > input_h or output_w > input_w:
+                if ((output_h > 1 and output_w > 1 and input_h > 1
+                     and input_w > 1) and (output_h - 1) % (input_h - 1)
+                        and (output_w - 1) % (input_w - 1)):
+                    warnings.warn(
+                        f'When align_corners={align_corners}, '
+                        'the output would be more aligned if '
+                        f'input size {(input_h, input_w)} is `x+1` and '
+                        f'out size {(output_h, output_w)} is `nx+1`')
+    return F.interpolate(input, size, scale_factor, mode, align_corners)
+
+
+class Upsample(nn.Module):
+
+    def __init__(self,
+                 size=None,
+                 scale_factor=None,
+                 mode='nearest',
+                 align_corners=None):
+        super(Upsample, self).__init__()
+        self.size = size
+        if isinstance(scale_factor, tuple):
+            self.scale_factor = tuple(float(factor) for factor in scale_factor)
+        else:
+            self.scale_factor = float(scale_factor) if scale_factor else None
+        self.mode = mode
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        if not self.size:
+            size = [int(t * self.scale_factor) for t in x.shape[-2:]]
+        else:
+            size = self.size
+        return resize(x, size, None, self.mode, self.align_corners)
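+
+
+# Example (illustrative): 2x bilinear upsampling
+#   >>> up = Upsample(scale_factor=2, mode='bilinear', align_corners=False)
+#   >>> up(torch.randn(1, 3, 32, 32)).shape
+#   torch.Size([1, 3, 64, 64])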
diff --git a/annotator/uniformer/mmseg/utils/__init__.py b/annotator/uniformer/mmseg/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac489e2dbbc0e6fa87f5088b4edcc20f8cadc1a6
--- /dev/null
+++ b/annotator/uniformer/mmseg/utils/__init__.py
@@ -0,0 +1,4 @@
+from .collect_env import collect_env
+from .logger import get_root_logger
+
+__all__ = ['get_root_logger', 'collect_env']
diff --git a/annotator/uniformer/mmseg/utils/collect_env.py b/annotator/uniformer/mmseg/utils/collect_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..65c2134ddbee9655161237dd0894d38c768c2624
--- /dev/null
+++ b/annotator/uniformer/mmseg/utils/collect_env.py
@@ -0,0 +1,17 @@
+from annotator.uniformer.mmcv.utils import collect_env as collect_base_env
+from annotator.uniformer.mmcv.utils import get_git_hash
+
+import annotator.uniformer.mmseg as mmseg
+
+
+def collect_env():
+    """Collect the information of the running environments."""
+    env_info = collect_base_env()
+    env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}'
+
+    return env_info
+
+
+if __name__ == '__main__':
+    for name, val in collect_env().items():
+        print('{}: {}'.format(name, val))
diff --git a/annotator/uniformer/mmseg/utils/logger.py b/annotator/uniformer/mmseg/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..4149d9eda3dfef07490352d22ac40c42460315e4
--- /dev/null
+++ b/annotator/uniformer/mmseg/utils/logger.py
@@ -0,0 +1,27 @@
+import logging
+
+from annotator.uniformer.mmcv.utils import get_logger
+
+
+def get_root_logger(log_file=None, log_level=logging.INFO):
+    """Get the root logger.
+
+    The logger will be initialized if it has not been initialized. By default a
+    StreamHandler will be added. If `log_file` is specified, a FileHandler will
+    also be added. The name of the root logger is the top-level package name,
+    e.g., "mmseg".
+
+    Args:
+        log_file (str | None): The log filename. If specified, a FileHandler
+            will be added to the root logger.
+        log_level (int): The root logger level. Note that only the process of
+            rank 0 is affected, while other processes will set the level to
+            "Error" and be silent most of the time.
+
+    Returns:
+        logging.Logger: The root logger.
+    """
+
+    logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level)
+
+    return logger
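+
+
+# Example (illustrative):
+#   >>> logger = get_root_logger(log_file='work_dirs/run.log')
+#   >>> logger.info('training started')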
diff --git a/annotator/util.py b/annotator/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..90831643d19cc1b9b0940df3d4fd4d846ba74a05
--- /dev/null
+++ b/annotator/util.py
@@ -0,0 +1,38 @@
+import numpy as np
+import cv2
+import os
+
+
+annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
+
+
+def HWC3(x):
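+    # Normalize any uint8 image to 3-channel HWC: grayscale is replicated,
+    # RGB is passed through, and RGBA is alpha-composited over white.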
+    assert x.dtype == np.uint8
+    if x.ndim == 2:
+        x = x[:, :, None]
+    assert x.ndim == 3
+    H, W, C = x.shape
+    assert C == 1 or C == 3 or C == 4
+    if C == 3:
+        return x
+    if C == 1:
+        return np.concatenate([x, x, x], axis=2)
+    if C == 4:
+        color = x[:, :, 0:3].astype(np.float32)
+        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+        y = color * alpha + 255.0 * (1.0 - alpha)
+        y = y.clip(0, 255).astype(np.uint8)
+        return y
+
+
+def resize_image(input_image, resolution):
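+    # Scale so the shorter side equals `resolution`, then round both sides to
+    # multiples of 64 so the size stays compatible with the model's downsampling
+    # stages; Lanczos for upscaling, area interpolation for downscaling.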
+    H, W, C = input_image.shape
+    H = float(H)
+    W = float(W)
+    k = float(resolution) / min(H, W)
+    H *= k
+    W *= k
+    H = int(np.round(H / 64.0)) * 64
+    W = int(np.round(W / 64.0)) * 64
+    img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
+    return img
diff --git a/app.py b/app.py
index 6592c957c7ef9dad34f25ce7aca8a17e2d3dc419..8f43b4188661f4e7eabd214a1c2f39aaa8c637eb 100644
--- a/app.py
+++ b/app.py
@@ -5,11 +5,8 @@ from streamlit_lottie import st_lottie
 from streamlit_option_menu import option_menu
 import requests
 import os 
-os.system('git clone https://github.com/lllyasviel/ControlNet.git')
-os.chdir('/home/user/app/ControlNet')
-
-from share import *
-import config
+# os.system('git clone https://github.com/lllyasviel/ControlNet.git')
+# os.chdir('/content/ControlNet')
 
 import cv2
 import einops
@@ -32,6 +29,8 @@ st.set_page_config(
         initial_sidebar_state="expanded"
     )
 
+save_memory = False
+
 @st.cache_resource
 def load_model():
     model_path = hf_hub_download('lllyasviel/ControlNet', 'models/control_sd15_scribble.pth')
@@ -64,14 +63,14 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
             seed = random.randint(0, 65535)
         seed_everything(seed)
 
-        if config.save_memory:
+        if save_memory:
             model.low_vram_shift(is_diffusing=False)
 
         cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
         un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
         shape = (4, H // 8, W // 8)
 
-        if config.save_memory:
+        if save_memory:
             model.low_vram_shift(is_diffusing=True)
 
         model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)  # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
@@ -80,7 +79,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
                                                      unconditional_guidance_scale=scale,
                                                      unconditional_conditioning=un_cond)
 
-        if config.save_memory:
+        if save_memory:
             model.low_vram_shift(is_diffusing=False)
 
         x_samples = model.decode_first_stage(samples)
@@ -227,4 +226,4 @@ def main():
                     col32.image(output_image, channels='RGB', width=384, clamp=True, caption='Generated image')
         
 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    main()
diff --git a/cldm/__pycache__/cldm.cpython-38.pyc b/cldm/__pycache__/cldm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e4a6ef1b71fe3f90d984369c72d5c4da5d81588d
Binary files /dev/null and b/cldm/__pycache__/cldm.cpython-38.pyc differ
diff --git a/cldm/__pycache__/ddim_hacked.cpython-38.pyc b/cldm/__pycache__/ddim_hacked.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a019a67866ce14028b679f22eddda9f34fd89c6b
Binary files /dev/null and b/cldm/__pycache__/ddim_hacked.cpython-38.pyc differ
diff --git a/cldm/__pycache__/model.cpython-38.pyc b/cldm/__pycache__/model.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5eda656718e885544e791b96abe9960c8e495351
Binary files /dev/null and b/cldm/__pycache__/model.cpython-38.pyc differ
diff --git a/cldm/cldm.py b/cldm/cldm.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b3ac7a575cf4933fc14dfc15dd3cca41cb3f3e8
--- /dev/null
+++ b/cldm/cldm.py
@@ -0,0 +1,435 @@
+import einops
+import torch
+import torch as th
+import torch.nn as nn
+
+from ldm.modules.diffusionmodules.util import (
+    conv_nd,
+    linear,
+    zero_module,
+    timestep_embedding,
+)
+
+from einops import rearrange, repeat
+from torchvision.utils import make_grid
+from ldm.modules.attention import SpatialTransformer
+from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
+from ldm.models.diffusion.ddpm import LatentDiffusion
+from ldm.util import log_txt_as_img, exists, instantiate_from_config
+from ldm.models.diffusion.ddim import DDIMSampler
+
+
+class ControlledUnetModel(UNetModel):
+    def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
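+        # The pretrained SD encoder and middle block run under no_grad; ControlNet
+        # residuals are added to the middle-block output and to each skip
+        # connection before the decoder blocks.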
+        hs = []
+        with torch.no_grad():
+            t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+            emb = self.time_embed(t_emb)
+            h = x.type(self.dtype)
+            for module in self.input_blocks:
+                h = module(h, emb, context)
+                hs.append(h)
+            h = self.middle_block(h, emb, context)
+
+        if control is not None:
+            h += control.pop()
+
+        for i, module in enumerate(self.output_blocks):
+            if only_mid_control or control is None:
+                h = torch.cat([h, hs.pop()], dim=1)
+            else:
+                h = torch.cat([h, hs.pop() + control.pop()], dim=1)
+            h = module(h, emb, context)
+
+        h = h.type(x.dtype)
+        return self.out(h)
+
+
+class ControlNet(nn.Module):
+    def __init__(
+            self,
+            image_size,
+            in_channels,
+            model_channels,
+            hint_channels,
+            num_res_blocks,
+            attention_resolutions,
+            dropout=0,
+            channel_mult=(1, 2, 4, 8),
+            conv_resample=True,
+            dims=2,
+            use_checkpoint=False,
+            use_fp16=False,
+            num_heads=-1,
+            num_head_channels=-1,
+            num_heads_upsample=-1,
+            use_scale_shift_norm=False,
+            resblock_updown=False,
+            use_new_attention_order=False,
+            use_spatial_transformer=False,  # custom transformer support
+            transformer_depth=1,  # custom transformer support
+            context_dim=None,  # custom transformer support
+            n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
+            legacy=True,
+            disable_self_attentions=None,
+            num_attention_blocks=None,
+            disable_middle_self_attn=False,
+            use_linear_in_transformer=False,
+    ):
+        super().__init__()
+        if use_spatial_transformer:
+            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+        if context_dim is not None:
+            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+            from omegaconf.listconfig import ListConfig
+            if type(context_dim) == ListConfig:
+                context_dim = list(context_dim)
+
+        if num_heads_upsample == -1:
+            num_heads_upsample = num_heads
+
+        if num_heads == -1:
+            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+        if num_head_channels == -1:
+            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+        self.dims = dims
+        self.image_size = image_size
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        if isinstance(num_res_blocks, int):
+            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+        else:
+            if len(num_res_blocks) != len(channel_mult):
+                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+                                 "as a list/tuple (per-level) with the same length as channel_mult")
+            self.num_res_blocks = num_res_blocks
+        if disable_self_attentions is not None:
+            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+            assert len(disable_self_attentions) == len(channel_mult)
+        if num_attention_blocks is not None:
+            assert len(num_attention_blocks) == len(self.num_res_blocks)
+            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+                  f"attention will still not be set.")
+
+        self.attention_resolutions = attention_resolutions
+        self.dropout = dropout
+        self.channel_mult = channel_mult
+        self.conv_resample = conv_resample
+        self.use_checkpoint = use_checkpoint
+        self.dtype = th.float16 if use_fp16 else th.float32
+        self.num_heads = num_heads
+        self.num_head_channels = num_head_channels
+        self.num_heads_upsample = num_heads_upsample
+        self.predict_codebook_ids = n_embed is not None
+
+        time_embed_dim = model_channels * 4
+        self.time_embed = nn.Sequential(
+            linear(model_channels, time_embed_dim),
+            nn.SiLU(),
+            linear(time_embed_dim, time_embed_dim),
+        )
+
+        self.input_blocks = nn.ModuleList(
+            [
+                TimestepEmbedSequential(
+                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
+                )
+            ]
+        )
+        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
+
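+        # Small encoder for the conditioning image: three stride-2 convs bring the
+        # hint from pixel space down to the latent resolution (8x), and the final
+        # zero-initialized conv makes its initial contribution zero.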
+        self.input_hint_block = TimestepEmbedSequential(
+            conv_nd(dims, hint_channels, 16, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 16, 16, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
+            nn.SiLU(),
+            conv_nd(dims, 32, 32, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
+            nn.SiLU(),
+            conv_nd(dims, 96, 96, 3, padding=1),
+            nn.SiLU(),
+            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
+            nn.SiLU(),
+            zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
+        )
+
+        self._feature_size = model_channels
+        input_block_chans = [model_channels]
+        ch = model_channels
+        ds = 1
+        for level, mult in enumerate(channel_mult):
+            for nr in range(self.num_res_blocks[level]):
+                layers = [
+                    ResBlock(
+                        ch,
+                        time_embed_dim,
+                        dropout,
+                        out_channels=mult * model_channels,
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                    )
+                ]
+                ch = mult * model_channels
+                if ds in attention_resolutions:
+                    if num_head_channels == -1:
+                        dim_head = ch // num_heads
+                    else:
+                        num_heads = ch // num_head_channels
+                        dim_head = num_head_channels
+                    if legacy:
+                        # num_heads = 1
+                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+                    if exists(disable_self_attentions):
+                        disabled_sa = disable_self_attentions[level]
+                    else:
+                        disabled_sa = False
+
+                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+                        layers.append(
+                            AttentionBlock(
+                                ch,
+                                use_checkpoint=use_checkpoint,
+                                num_heads=num_heads,
+                                num_head_channels=dim_head,
+                                use_new_attention_order=use_new_attention_order,
+                            ) if not use_spatial_transformer else SpatialTransformer(
+                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+                                use_checkpoint=use_checkpoint
+                            )
+                        )
+                self.input_blocks.append(TimestepEmbedSequential(*layers))
+                self.zero_convs.append(self.make_zero_conv(ch))
+                self._feature_size += ch
+                input_block_chans.append(ch)
+            if level != len(channel_mult) - 1:
+                out_ch = ch
+                self.input_blocks.append(
+                    TimestepEmbedSequential(
+                        ResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            down=True,
+                        )
+                        if resblock_updown
+                        else Downsample(
+                            ch, conv_resample, dims=dims, out_channels=out_ch
+                        )
+                    )
+                )
+                ch = out_ch
+                input_block_chans.append(ch)
+                self.zero_convs.append(self.make_zero_conv(ch))
+                ds *= 2
+                self._feature_size += ch
+
+        if num_head_channels == -1:
+            dim_head = ch // num_heads
+        else:
+            num_heads = ch // num_head_channels
+            dim_head = num_head_channels
+        if legacy:
+            # num_heads = 1
+            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+        self.middle_block = TimestepEmbedSequential(
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+            AttentionBlock(
+                ch,
+                use_checkpoint=use_checkpoint,
+                num_heads=num_heads,
+                num_head_channels=dim_head,
+                use_new_attention_order=use_new_attention_order,
+            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
+                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+                use_checkpoint=use_checkpoint
+            ),
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+        )
+        self.middle_block_out = self.make_zero_conv(ch)
+        self._feature_size += ch
+
+    def make_zero_conv(self, channels):
+        return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
+
+    def forward(self, x, hint, timesteps, context, **kwargs):
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+        emb = self.time_embed(t_emb)
+
+        guided_hint = self.input_hint_block(hint, emb, context)
+
+        outs = []
+
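+        # The encoded hint is added once, to the output of the first input block;
+        # every block's output then passes through its zero conv and is collected
+        # as a residual for the main UNet.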
+        h = x.type(self.dtype)
+        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
+            if guided_hint is not None:
+                h = module(h, emb, context)
+                h += guided_hint
+                guided_hint = None
+            else:
+                h = module(h, emb, context)
+            outs.append(zero_conv(h, emb, context))
+
+        h = self.middle_block(h, emb, context)
+        outs.append(self.middle_block_out(h, emb, context))
+
+        return outs
+
+
+class ControlLDM(LatentDiffusion):
+
+    def __init__(self, control_stage_config, control_key, only_mid_control, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.control_model = instantiate_from_config(control_stage_config)
+        self.control_key = control_key
+        self.only_mid_control = only_mid_control
+        self.control_scales = [1.0] * 13
+
+    @torch.no_grad()
+    def get_input(self, batch, k, bs=None, *args, **kwargs):
+        x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
+        control = batch[self.control_key]
+        if bs is not None:
+            control = control[:bs]
+        control = control.to(self.device)
+        control = einops.rearrange(control, 'b h w c -> b c h w')
+        control = control.to(memory_format=torch.contiguous_format).float()
+        return x, dict(c_crossattn=[c], c_concat=[control])
+
+    def apply_model(self, x_noisy, t, cond, *args, **kwargs):
+        assert isinstance(cond, dict)
+        diffusion_model = self.model.diffusion_model
+
+        cond_txt = torch.cat(cond['c_crossattn'], 1)
+
+        if cond['c_concat'] is None:
+            eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
+        else:
+            control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
+            control = [c * scale for c, scale in zip(control, self.control_scales)]
+            eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
+
+        return eps
+
+    @torch.no_grad()
+    def get_unconditional_conditioning(self, N):
+        return self.get_learned_conditioning([""] * N)
+
+    @torch.no_grad()
+    def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
+                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+                   plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
+                   use_ema_scope=True,
+                   **kwargs):
+        use_ddim = ddim_steps is not None
+
+        log = dict()
+        z, c = self.get_input(batch, self.first_stage_key, bs=N)
+        c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
+        N = min(z.shape[0], N)
+        n_row = min(z.shape[0], n_row)
+        log["reconstruction"] = self.decode_first_stage(z)
+        log["control"] = c_cat * 2.0 - 1.0
+        log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)
+
+        if plot_diffusion_rows:
+            # get diffusion row
+            diffusion_row = list()
+            z_start = z[:n_row]
+            for t in range(self.num_timesteps):
+                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+                    t = t.to(self.device).long()
+                    noise = torch.randn_like(z_start)
+                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+                    diffusion_row.append(self.decode_first_stage(z_noisy))
+
+            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
+            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+            log["diffusion_row"] = diffusion_grid
+
+        if sample:
+            # get denoise row
+            samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+                                                     batch_size=N, ddim=use_ddim,
+                                                     ddim_steps=ddim_steps, eta=ddim_eta)
+            x_samples = self.decode_first_stage(samples)
+            log["samples"] = x_samples
+            if plot_denoise_rows:
+                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                log["denoise_row"] = denoise_grid
+
+        if unconditional_guidance_scale > 1.0:
+            uc_cross = self.get_unconditional_conditioning(N)
+            uc_cat = c_cat  # torch.zeros_like(c_cat)
+            uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
+            samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+                                             batch_size=N, ddim=use_ddim,
+                                             ddim_steps=ddim_steps, eta=ddim_eta,
+                                             unconditional_guidance_scale=unconditional_guidance_scale,
+                                             unconditional_conditioning=uc_full,
+                                             )
+            x_samples_cfg = self.decode_first_stage(samples_cfg)
+            log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+        return log
+
+    @torch.no_grad()
+    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
+        ddim_sampler = DDIMSampler(self)
+        b, c, h, w = cond["c_concat"][0].shape
+        shape = (self.channels, h // 8, w // 8)
+        samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
+        return samples, intermediates
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        params = list(self.control_model.parameters())
+        if not self.sd_locked:
+            params += list(self.model.diffusion_model.output_blocks.parameters())
+            params += list(self.model.diffusion_model.out.parameters())
+        opt = torch.optim.AdamW(params, lr=lr)
+        return opt
+
+    def low_vram_shift(self, is_diffusing):
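+        # Keep only the modules needed for the current phase on the GPU:
+        # UNet + ControlNet while diffusing, VAE + text encoder otherwise.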
+        if is_diffusing:
+            self.model = self.model.cuda()
+            self.control_model = self.control_model.cuda()
+            self.first_stage_model = self.first_stage_model.cpu()
+            self.cond_stage_model = self.cond_stage_model.cpu()
+        else:
+            self.model = self.model.cpu()
+            self.control_model = self.control_model.cpu()
+            self.first_stage_model = self.first_stage_model.cuda()
+            self.cond_stage_model = self.cond_stage_model.cuda()
diff --git a/cldm/ddim_hacked.py b/cldm/ddim_hacked.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c040b363ba0705f52509b75437b5ea932c80ec1
--- /dev/null
+++ b/cldm/ddim_hacked.py
@@ -0,0 +1,316 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
+
+
+class DDIMSampler(object):
+    def __init__(self, model, schedule="linear", **kwargs):
+        super().__init__()
+        self.model = model
+        self.ddpm_num_timesteps = model.num_timesteps
+        self.schedule = schedule
+
+    def register_buffer(self, name, attr):
+        if type(attr) == torch.Tensor:
+            if attr.device != torch.device("cuda"):
+                attr = attr.to(torch.device("cuda"))
+        setattr(self, name, attr)
+
+    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+        alphas_cumprod = self.model.alphas_cumprod
+        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+        self.register_buffer('betas', to_torch(self.model.betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+        # ddim sampling parameters
+        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+                                                                                   ddim_timesteps=self.ddim_timesteps,
+                                                                                   eta=ddim_eta,verbose=verbose)
+        self.register_buffer('ddim_sigmas', ddim_sigmas)
+        self.register_buffer('ddim_alphas', ddim_alphas)
+        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+    @torch.no_grad()
+    def sample(self,
+               S,
+               batch_size,
+               shape,
+               conditioning=None,
+               callback=None,
+               normals_sequence=None,
+               img_callback=None,
+               quantize_x0=False,
+               eta=0.,
+               mask=None,
+               x0=None,
+               temperature=1.,
+               noise_dropout=0.,
+               score_corrector=None,
+               corrector_kwargs=None,
+               verbose=True,
+               x_T=None,
+               log_every_t=100,
+               unconditional_guidance_scale=1.,
+               unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               dynamic_threshold=None,
+               ucg_schedule=None,
+               **kwargs
+               ):
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                ctmp = conditioning[list(conditioning.keys())[0]]
+                while isinstance(ctmp, list): ctmp = ctmp[0]
+                cbs = ctmp.shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+            elif isinstance(conditioning, list):
+                for ctmp in conditioning:
+                    if ctmp.shape[0] != batch_size:
+                        print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+        print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+        samples, intermediates = self.ddim_sampling(conditioning, size,
+                                                    callback=callback,
+                                                    img_callback=img_callback,
+                                                    quantize_denoised=quantize_x0,
+                                                    mask=mask, x0=x0,
+                                                    ddim_use_original_steps=False,
+                                                    noise_dropout=noise_dropout,
+                                                    temperature=temperature,
+                                                    score_corrector=score_corrector,
+                                                    corrector_kwargs=corrector_kwargs,
+                                                    x_T=x_T,
+                                                    log_every_t=log_every_t,
+                                                    unconditional_guidance_scale=unconditional_guidance_scale,
+                                                    unconditional_conditioning=unconditional_conditioning,
+                                                    dynamic_threshold=dynamic_threshold,
+                                                    ucg_schedule=ucg_schedule
+                                                    )
+        return samples, intermediates
+
+    @torch.no_grad()
+    def ddim_sampling(self, cond, shape,
+                      x_T=None, ddim_use_original_steps=False,
+                      callback=None, timesteps=None, quantize_denoised=False,
+                      mask=None, x0=None, img_callback=None, log_every_t=100,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
+                      ucg_schedule=None):
+        device = self.model.betas.device
+        b = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=device)
+        else:
+            img = x_T
+
+        if timesteps is None:
+            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+        elif timesteps is not None and not ddim_use_original_steps:
+            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+            timesteps = self.ddim_timesteps[:subset_end]
+
+        intermediates = {'x_inter': [img], 'pred_x0': [img]}
+        time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+        print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+            if mask is not None:
+                assert x0 is not None
+                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
+                img = img_orig * mask + (1. - mask) * img
+
+            if ucg_schedule is not None:
+                assert len(ucg_schedule) == len(time_range)
+                unconditional_guidance_scale = ucg_schedule[i]
+
+            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+                                      quantize_denoised=quantize_denoised, temperature=temperature,
+                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
+                                      corrector_kwargs=corrector_kwargs,
+                                      unconditional_guidance_scale=unconditional_guidance_scale,
+                                      unconditional_conditioning=unconditional_conditioning,
+                                      dynamic_threshold=dynamic_threshold)
+            img, pred_x0 = outs
+            if callback: callback(i)
+            if img_callback: img_callback(pred_x0, i)
+
+            if index % log_every_t == 0 or index == total_steps - 1:
+                intermediates['x_inter'].append(img)
+                intermediates['pred_x0'].append(pred_x0)
+
+        return img, intermediates
+
+    @torch.no_grad()
+    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None,
+                      dynamic_threshold=None):
+        b, *_, device = *x.shape, x.device
+
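+        # Classifier-free guidance: extrapolate from the unconditional prediction
+        # towards the conditional one by `unconditional_guidance_scale`.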
+        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+            model_output = self.model.apply_model(x, t, c)
+        else:
+            model_t = self.model.apply_model(x, t, c)
+            model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
+            model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
+
+        if self.model.parameterization == "v":
+            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+        else:
+            e_t = model_output
+
+        if score_corrector is not None:
+            assert self.model.parameterization == "eps", 'not implemented'
+            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+        # select parameters corresponding to the currently considered timestep
+        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+        # current prediction for x_0
+        if self.model.parameterization != "v":
+            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        else:
+            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+        if quantize_denoised:
+            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+        if dynamic_threshold is not None:
+            raise NotImplementedError()
+
+        # direction pointing to x_t
+        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+        if noise_dropout > 0.:
+            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+        return x_prev, pred_x0
+
+    @torch.no_grad()
+    def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
+               unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
+        num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
+
+        assert t_enc <= num_reference_steps
+        num_steps = t_enc
+
+        if use_original_steps:
+            alphas_next = self.alphas_cumprod[:num_steps]
+            alphas = self.alphas_cumprod_prev[:num_steps]
+        else:
+            alphas_next = self.ddim_alphas[:num_steps]
+            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+        x_next = x0
+        intermediates = []
+        inter_steps = []
+        for i in tqdm(range(num_steps), desc='Encoding Image'):
+            t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
+            if unconditional_guidance_scale == 1.:
+                noise_pred = self.model.apply_model(x_next, t, c)
+            else:
+                assert unconditional_conditioning is not None
+                e_t_uncond, noise_pred = torch.chunk(
+                    self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
+                                           torch.cat((unconditional_conditioning, c))), 2)
+                noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
+
+            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+            weighted_noise_pred = alphas_next[i].sqrt() * (
+                    (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
+            x_next = xt_weighted + weighted_noise_pred
+            if return_intermediates and i % (
+                    num_steps // return_intermediates) == 0 and i < num_steps - 1:
+                intermediates.append(x_next)
+                inter_steps.append(i)
+            elif return_intermediates and i >= num_steps - 2:
+                intermediates.append(x_next)
+                inter_steps.append(i)
+            if callback: callback(i)
+
+        out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
+        if return_intermediates:
+            out.update({'intermediates': intermediates})
+        return x_next, out
+
+    @torch.no_grad()
+    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+        # fast, but does not allow for exact reconstruction
+        # t serves as an index to gather the correct alphas
+        if use_original_steps:
+            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+        else:
+            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+        if noise is None:
+            noise = torch.randn_like(x0)
+        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+                extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
+
+    @torch.no_grad()
+    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+               use_original_steps=False, callback=None):
+
+        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+        timesteps = timesteps[:t_start]
+
+        time_range = np.flip(timesteps)
+        total_steps = timesteps.shape[0]
+        print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+        x_dec = x_latent
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
+                                          unconditional_guidance_scale=unconditional_guidance_scale,
+                                          unconditional_conditioning=unconditional_conditioning)
+            if callback: callback(i)
+        return x_dec
\ No newline at end of file
diff --git a/cldm/hack.py b/cldm/hack.py
new file mode 100644
index 0000000000000000000000000000000000000000..454361e9d036cd1a6a79122c2fd16b489e4767b1
--- /dev/null
+++ b/cldm/hack.py
@@ -0,0 +1,111 @@
+import torch
+import einops
+
+import ldm.modules.encoders.modules
+import ldm.modules.attention
+
+from transformers import logging
+from ldm.modules.attention import default
+
+
+def disable_verbosity():
+    logging.set_verbosity_error()
+    print('logging improved.')
+    return
+
+
+def enable_sliced_attention():
+    ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attention_forward
+    print('Enabled sliced_attention.')
+    return
+
+
+def hack_everything(clip_skip=0):
+    disable_verbosity()
+    ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
+    ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
+    print('Enabled clip hacks.')
+    return
+
+
+# Written by Lvmin
+def _hacked_clip_forward(self, text):
+    PAD = self.tokenizer.pad_token_id
+    EOS = self.tokenizer.eos_token_id
+    BOS = self.tokenizer.bos_token_id
+
+    def tokenize(t):
+        return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
+
+    def transformer_encode(t):
+        if self.clip_skip > 1:
+            rt = self.transformer(input_ids=t, output_hidden_states=True)
+            return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
+        else:
+            return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
+
+    def split(x):
+        return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
+
+    def pad(x, p, i):
+        return x[:i] if len(x) >= i else x + [p] * (i - len(x))
+
+    raw_tokens_list = tokenize(text)
+    tokens_list = []
+
+    for raw_tokens in raw_tokens_list:
+        raw_tokens_123 = split(raw_tokens)
+        raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
+        raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
+        tokens_list.append(raw_tokens_123)
+
+    tokens_list = torch.IntTensor(tokens_list).to(self.device)
+
+    feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
+    y = transformer_encode(feed)
+    z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
+
+    return z
+
+
+# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
+def _hacked_sliced_attention_forward(self, x, context=None, mask=None):
+    h = self.heads
+
+    q = self.to_q(x)
+    context = default(context, x)
+    k = self.to_k(context)
+    v = self.to_v(context)
+    del context, x
+
+    q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+    limit = k.shape[0]
+    att_step = 1
+    q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
+    k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
+    v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
+
+    q_chunks.reverse()
+    k_chunks.reverse()
+    v_chunks.reverse()
+    sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
+    del k, q, v
+    for i in range(0, limit, att_step):
+        q_buffer = q_chunks.pop()
+        k_buffer = k_chunks.pop()
+        v_buffer = v_chunks.pop()
+        sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
+
+        del k_buffer, q_buffer
+        # attention, what we cannot get enough of, by chunks
+
+        sim_buffer = sim_buffer.softmax(dim=-1)
+
+        sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
+        del v_buffer
+        sim[i:i + att_step, :, :] = sim_buffer
+
+        del sim_buffer
+    sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
+    return self.to_out(sim)
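+
+
+# Illustrative usage sketch (the call site below is an assumption, not part of this module):
+#   from cldm.hack import disable_verbosity, enable_sliced_attention
+#   disable_verbosity()          # silence transformers warnings at startup
+#   enable_sliced_attention()    # optional: chunked cross-attention for lower VRAM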
diff --git a/cldm/logger.py b/cldm/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a8803846f2a8979f87f3cf9ea5b12869439e62f
--- /dev/null
+++ b/cldm/logger.py
@@ -0,0 +1,76 @@
+import os
+
+import numpy as np
+import torch
+import torchvision
+from PIL import Image
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.utilities.distributed import rank_zero_only
+
+
+class ImageLogger(Callback):
+    def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True,
+                 rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
+                 log_images_kwargs=None):
+        super().__init__()
+        self.rescale = rescale
+        self.batch_freq = batch_frequency
+        self.max_images = max_images
+        if not increase_log_steps:
+            self.log_steps = [self.batch_freq]
+        self.clamp = clamp
+        self.disabled = disabled
+        self.log_on_batch_idx = log_on_batch_idx
+        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
+        self.log_first_step = log_first_step
+
+    @rank_zero_only
+    def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
+        root = os.path.join(save_dir, "image_log", split)
+        for k in images:
+            grid = torchvision.utils.make_grid(images[k], nrow=4)
+            if self.rescale:
+                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
+            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
+            grid = grid.numpy()
+            grid = (grid * 255).astype(np.uint8)
+            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
+            path = os.path.join(root, filename)
+            os.makedirs(os.path.split(path)[0], exist_ok=True)
+            Image.fromarray(grid).save(path)
+
+    def log_img(self, pl_module, batch, batch_idx, split="train"):
+        check_idx = batch_idx  # if self.log_on_batch_idx else pl_module.global_step
+        if (self.check_frequency(check_idx) and  # batch_idx % self.batch_freq == 0
+                hasattr(pl_module, "log_images") and
+                callable(pl_module.log_images) and
+                self.max_images > 0):
+            logger = type(pl_module.logger)
+
+            is_train = pl_module.training
+            if is_train:
+                pl_module.eval()
+
+            with torch.no_grad():
+                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
+
+            for k in images:
+                N = min(images[k].shape[0], self.max_images)
+                images[k] = images[k][:N]
+                if isinstance(images[k], torch.Tensor):
+                    images[k] = images[k].detach().cpu()
+                    if self.clamp:
+                        images[k] = torch.clamp(images[k], -1., 1.)
+
+            self.log_local(pl_module.logger.save_dir, split, images,
+                           pl_module.global_step, pl_module.current_epoch, batch_idx)
+
+            if is_train:
+                pl_module.train()
+
+    def check_frequency(self, check_idx):
+        return check_idx % self.batch_freq == 0
+
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+        if not self.disabled:
+            self.log_img(pl_module, batch, batch_idx, split="train")
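+
+
+# Illustrative usage sketch (trainer arguments below are placeholders, not part of this module):
+#   import pytorch_lightning as pl
+#   image_logger = ImageLogger(batch_frequency=300)
+#   trainer = pl.Trainer(gpus=1, precision=32, callbacks=[image_logger])
+#   trainer.fit(model, train_dataloader)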
diff --git a/cldm/model.py b/cldm/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..fed3c31ac145b78907c7f771d1d8db6fb32d92ed
--- /dev/null
+++ b/cldm/model.py
@@ -0,0 +1,28 @@
+import os
+import torch
+
+from omegaconf import OmegaConf
+from ldm.util import instantiate_from_config
+
+
+def get_state_dict(d):
+    return d.get('state_dict', d)
+
+
+def load_state_dict(ckpt_path, location='cpu'):
+    _, extension = os.path.splitext(ckpt_path)
+    if extension.lower() == ".safetensors":
+        import safetensors.torch
+        state_dict = safetensors.torch.load_file(ckpt_path, device=location)
+    else:
+        state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
+    state_dict = get_state_dict(state_dict)
+    print(f'Loaded state_dict from [{ckpt_path}]')
+    return state_dict
+
+
+def create_model(config_path):
+    config = OmegaConf.load(config_path)
+    model = instantiate_from_config(config.model).cpu()
+    print(f'Loaded model config from [{config_path}]')
+    return model
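+
+
+# Illustrative usage sketch (the config and checkpoint paths are placeholders):
+#   model = create_model('./models/cldm_v15.yaml')
+#   model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location='cpu'))
+#   model = model.cuda()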
diff --git a/ldm/__pycache__/util.cpython-38.pyc b/ldm/__pycache__/util.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9708c37a68699410e5d6153b25d207ca1ac0a0ec
Binary files /dev/null and b/ldm/__pycache__/util.cpython-38.pyc differ
diff --git a/ldm/data/__init__.py b/ldm/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/data/util.py b/ldm/data/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b60ceb2349e3bd7900ff325740e2022d2903b1c
--- /dev/null
+++ b/ldm/data/util.py
@@ -0,0 +1,24 @@
+import torch
+
+from ldm.modules.midas.api import load_midas_transform
+
+
+class AddMiDaS(object):
+    def __init__(self, model_type):
+        super().__init__()
+        self.transform = load_midas_transform(model_type)
+
+    def pt2np(self, x):
+        x = ((x + 1.0) * .5).detach().cpu().numpy()
+        return x
+
+    def np2pt(self, x):
+        x = torch.from_numpy(x) * 2 - 1.
+        return x
+
+    def __call__(self, sample):
+        # sample['jpg'] is tensor hwc in [-1, 1] at this point
+        x = self.pt2np(sample['jpg'])
+        x = self.transform({"image": x})["image"]
+        sample['midas_in'] = x
+        return sample
\ No newline at end of file
diff --git a/ldm/models/__pycache__/autoencoder.cpython-38.pyc b/ldm/models/__pycache__/autoencoder.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..149d7395207d03a3d8d10ed072835a9bff283725
Binary files /dev/null and b/ldm/models/__pycache__/autoencoder.cpython-38.pyc differ
diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d122549995ce2cd64092c81a58419ed4a15a02fd
--- /dev/null
+++ b/ldm/models/autoencoder.py
@@ -0,0 +1,219 @@
+import torch
+import pytorch_lightning as pl
+import torch.nn.functional as F
+from contextlib import contextmanager
+
+from ldm.modules.diffusionmodules.model import Encoder, Decoder
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+
+from ldm.util import instantiate_from_config
+from ldm.modules.ema import LitEma
+
+
+class AutoencoderKL(pl.LightningModule):
+    def __init__(self,
+                 ddconfig,
+                 lossconfig,
+                 embed_dim,
+                 ckpt_path=None,
+                 ignore_keys=[],
+                 image_key="image",
+                 colorize_nlabels=None,
+                 monitor=None,
+                 ema_decay=None,
+                 learn_logvar=False
+                 ):
+        super().__init__()
+        self.learn_logvar = learn_logvar
+        self.image_key = image_key
+        self.encoder = Encoder(**ddconfig)
+        self.decoder = Decoder(**ddconfig)
+        self.loss = instantiate_from_config(lossconfig)
+        assert ddconfig["double_z"]
+        self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
+        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+        self.embed_dim = embed_dim
+        if colorize_nlabels is not None:
+            assert type(colorize_nlabels)==int
+            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+        if monitor is not None:
+            self.monitor = monitor
+
+        self.use_ema = ema_decay is not None
+        if self.use_ema:
+            self.ema_decay = ema_decay
+            assert 0. < ema_decay < 1.
+            self.model_ema = LitEma(self, decay=ema_decay)
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+    def init_from_ckpt(self, path, ignore_keys=list()):
+        sd = torch.load(path, map_location="cpu")["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+        self.load_state_dict(sd, strict=False)
+        print(f"Restored from {path}")
+
+    @contextmanager
+    def ema_scope(self, context=None):
+        if self.use_ema:
+            self.model_ema.store(self.parameters())
+            self.model_ema.copy_to(self)
+            if context is not None:
+                print(f"{context}: Switched to EMA weights")
+        try:
+            yield None
+        finally:
+            if self.use_ema:
+                self.model_ema.restore(self.parameters())
+                if context is not None:
+                    print(f"{context}: Restored training weights")
+
+    def on_train_batch_end(self, *args, **kwargs):
+        if self.use_ema:
+            self.model_ema(self)
+
+    def encode(self, x):
+        h = self.encoder(x)
+        moments = self.quant_conv(h)
+        posterior = DiagonalGaussianDistribution(moments)
+        return posterior
+
+    def decode(self, z):
+        z = self.post_quant_conv(z)
+        dec = self.decoder(z)
+        return dec
+
+    def forward(self, input, sample_posterior=True):
+        posterior = self.encode(input)
+        if sample_posterior:
+            z = posterior.sample()
+        else:
+            z = posterior.mode()
+        dec = self.decode(z)
+        return dec, posterior
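+
+    # Note (illustrative, assuming the usual KL-f8 configuration with 4 latent channels):
+    # a (B, 3, H, W) image in [-1, 1] yields a diagonal-Gaussian posterior over
+    # (B, 4, H/8, W/8) latents, and decode() maps sampled latents back to image space.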
+
+    def get_input(self, batch, k):
+        x = batch[k]
+        if len(x.shape) == 3:
+            x = x[..., None]
+        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+        return x
+
+    def training_step(self, batch, batch_idx, optimizer_idx):
+        inputs = self.get_input(batch, self.image_key)
+        reconstructions, posterior = self(inputs)
+
+        if optimizer_idx == 0:
+            # train encoder+decoder+logvar
+            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+                                            last_layer=self.get_last_layer(), split="train")
+            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+            return aeloss
+
+        if optimizer_idx == 1:
+            # train the discriminator
+            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+                                                last_layer=self.get_last_layer(), split="train")
+
+            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+            return discloss
+
+    def validation_step(self, batch, batch_idx):
+        log_dict = self._validation_step(batch, batch_idx)
+        with self.ema_scope():
+            log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
+        return log_dict
+
+    def _validation_step(self, batch, batch_idx, postfix=""):
+        inputs = self.get_input(batch, self.image_key)
+        reconstructions, posterior = self(inputs)
+        aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
+                                        last_layer=self.get_last_layer(), split="val"+postfix)
+
+        discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
+                                            last_layer=self.get_last_layer(), split="val"+postfix)
+
+        self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
+        self.log_dict(log_dict_ae)
+        self.log_dict(log_dict_disc)
+        return self.log_dict
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
+            self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
+        if self.learn_logvar:
+            print(f"{self.__class__.__name__}: Learning logvar")
+            ae_params_list.append(self.loss.logvar)
+        opt_ae = torch.optim.Adam(ae_params_list,
+                                  lr=lr, betas=(0.5, 0.9))
+        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+                                    lr=lr, betas=(0.5, 0.9))
+        return [opt_ae, opt_disc], []
+
+    def get_last_layer(self):
+        return self.decoder.conv_out.weight
+
+    @torch.no_grad()
+    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
+        log = dict()
+        x = self.get_input(batch, self.image_key)
+        x = x.to(self.device)
+        if not only_inputs:
+            xrec, posterior = self(x)
+            if x.shape[1] > 3:
+                # colorize with random projection
+                assert xrec.shape[1] > 3
+                x = self.to_rgb(x)
+                xrec = self.to_rgb(xrec)
+            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
+            log["reconstructions"] = xrec
+            if log_ema or self.use_ema:
+                with self.ema_scope():
+                    xrec_ema, posterior_ema = self(x)
+                    if x.shape[1] > 3:
+                        # colorize with random projection
+                        assert xrec_ema.shape[1] > 3
+                        xrec_ema = self.to_rgb(xrec_ema)
+                    log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
+                    log["reconstructions_ema"] = xrec_ema
+        log["inputs"] = x
+        return log
+
+    def to_rgb(self, x):
+        assert self.image_key == "segmentation"
+        if not hasattr(self, "colorize"):
+            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+        x = F.conv2d(x, weight=self.colorize)
+        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+        return x
+
+
+class IdentityFirstStage(torch.nn.Module):
+    def __init__(self, *args, vq_interface=False, **kwargs):
+        self.vq_interface = vq_interface
+        super().__init__()
+
+    def encode(self, x, *args, **kwargs):
+        return x
+
+    def decode(self, x, *args, **kwargs):
+        return x
+
+    def quantize(self, x, *args, **kwargs):
+        if self.vq_interface:
+            return x, None, [None, None, None]
+        return x
+
+    def forward(self, x, *args, **kwargs):
+        return x
+
diff --git a/ldm/models/diffusion/__init__.py b/ldm/models/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc b/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5a28bd02bc035a48e4077212a45f232a5922b06a
Binary files /dev/null and b/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc differ
diff --git a/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc b/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c52b723f1e4d09df236fd98e21facfd0a6fbef68
Binary files /dev/null and b/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc differ
diff --git a/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc b/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2288e94b0e1a9e2db5f7a24829aecb2350ba12eb
Binary files /dev/null and b/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc differ
diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
new file mode 100644
index 0000000000000000000000000000000000000000..27ead0ea914c64c747b64e690662899fb3801144
--- /dev/null
+++ b/ldm/models/diffusion/ddim.py
@@ -0,0 +1,336 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
+
+
+class DDIMSampler(object):
+    def __init__(self, model, schedule="linear", **kwargs):
+        super().__init__()
+        self.model = model
+        self.ddpm_num_timesteps = model.num_timesteps
+        self.schedule = schedule
+
+    def register_buffer(self, name, attr):
+        if type(attr) == torch.Tensor:
+            if attr.device != torch.device("cuda"):
+                attr = attr.to(torch.device("cuda"))
+        setattr(self, name, attr)
+
+    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+        alphas_cumprod = self.model.alphas_cumprod
+        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+        self.register_buffer('betas', to_torch(self.model.betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+        # ddim sampling parameters
+        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+                                                                                   ddim_timesteps=self.ddim_timesteps,
+                                                                                   eta=ddim_eta,verbose=verbose)
+        self.register_buffer('ddim_sigmas', ddim_sigmas)
+        self.register_buffer('ddim_alphas', ddim_alphas)
+        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+    @torch.no_grad()
+    def sample(self,
+               S,
+               batch_size,
+               shape,
+               conditioning=None,
+               callback=None,
+               normals_sequence=None,
+               img_callback=None,
+               quantize_x0=False,
+               eta=0.,
+               mask=None,
+               x0=None,
+               temperature=1.,
+               noise_dropout=0.,
+               score_corrector=None,
+               corrector_kwargs=None,
+               verbose=True,
+               x_T=None,
+               log_every_t=100,
+               unconditional_guidance_scale=1.,
+               unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               dynamic_threshold=None,
+               ucg_schedule=None,
+               **kwargs
+               ):
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                ctmp = conditioning[list(conditioning.keys())[0]]
+                while isinstance(ctmp, list): ctmp = ctmp[0]
+                cbs = ctmp.shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+            elif isinstance(conditioning, list):
+                for ctmp in conditioning:
+                    if ctmp.shape[0] != batch_size:
+                        print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+        print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+        samples, intermediates = self.ddim_sampling(conditioning, size,
+                                                    callback=callback,
+                                                    img_callback=img_callback,
+                                                    quantize_denoised=quantize_x0,
+                                                    mask=mask, x0=x0,
+                                                    ddim_use_original_steps=False,
+                                                    noise_dropout=noise_dropout,
+                                                    temperature=temperature,
+                                                    score_corrector=score_corrector,
+                                                    corrector_kwargs=corrector_kwargs,
+                                                    x_T=x_T,
+                                                    log_every_t=log_every_t,
+                                                    unconditional_guidance_scale=unconditional_guidance_scale,
+                                                    unconditional_conditioning=unconditional_conditioning,
+                                                    dynamic_threshold=dynamic_threshold,
+                                                    ucg_schedule=ucg_schedule
+                                                    )
+        return samples, intermediates
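+
+    # Illustrative call (shapes and guidance values below are placeholders; cond/uncond
+    # stand for already-encoded prompt conditionings):
+    #   sampler = DDIMSampler(model)
+    #   samples, _ = sampler.sample(S=50, batch_size=1, shape=(4, 64, 64),
+    #                               conditioning=cond, eta=0.0,
+    #                               unconditional_guidance_scale=9.0,
+    #                               unconditional_conditioning=uncond)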
+
+    @torch.no_grad()
+    def ddim_sampling(self, cond, shape,
+                      x_T=None, ddim_use_original_steps=False,
+                      callback=None, timesteps=None, quantize_denoised=False,
+                      mask=None, x0=None, img_callback=None, log_every_t=100,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
+                      ucg_schedule=None):
+        device = self.model.betas.device
+        b = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=device)
+        else:
+            img = x_T
+
+        if timesteps is None:
+            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+        elif timesteps is not None and not ddim_use_original_steps:
+            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+            timesteps = self.ddim_timesteps[:subset_end]
+
+        intermediates = {'x_inter': [img], 'pred_x0': [img]}
+        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+        print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+            if mask is not None:
+                assert x0 is not None
+                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
+                img = img_orig * mask + (1. - mask) * img
+
+            if ucg_schedule is not None:
+                assert len(ucg_schedule) == len(time_range)
+                unconditional_guidance_scale = ucg_schedule[i]
+
+            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+                                      quantize_denoised=quantize_denoised, temperature=temperature,
+                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
+                                      corrector_kwargs=corrector_kwargs,
+                                      unconditional_guidance_scale=unconditional_guidance_scale,
+                                      unconditional_conditioning=unconditional_conditioning,
+                                      dynamic_threshold=dynamic_threshold)
+            img, pred_x0 = outs
+            if callback: callback(i)
+            if img_callback: img_callback(pred_x0, i)
+
+            if index % log_every_t == 0 or index == total_steps - 1:
+                intermediates['x_inter'].append(img)
+                intermediates['pred_x0'].append(pred_x0)
+
+        return img, intermediates
+
+    @torch.no_grad()
+    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None,
+                      dynamic_threshold=None):
+        b, *_, device = *x.shape, x.device
+
+        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+            model_output = self.model.apply_model(x, t, c)
+        else:
+            x_in = torch.cat([x] * 2)
+            t_in = torch.cat([t] * 2)
+            if isinstance(c, dict):
+                assert isinstance(unconditional_conditioning, dict)
+                c_in = dict()
+                for k in c:
+                    if isinstance(c[k], list):
+                        c_in[k] = [torch.cat([
+                            unconditional_conditioning[k][i],
+                            c[k][i]]) for i in range(len(c[k]))]
+                    else:
+                        c_in[k] = torch.cat([
+                                unconditional_conditioning[k],
+                                c[k]])
+            elif isinstance(c, list):
+                c_in = list()
+                assert isinstance(unconditional_conditioning, list)
+                for i in range(len(c)):
+                    c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
+            else:
+                c_in = torch.cat([unconditional_conditioning, c])
+            model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+            model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
+
+        if self.model.parameterization == "v":
+            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+        else:
+            e_t = model_output
+
+        if score_corrector is not None:
+            assert self.model.parameterization == "eps", 'not implemented'
+            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+        # select parameters corresponding to the currently considered timestep
+        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
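+        # The lines below implement one DDIM update step:
+        #   x_{t-1} = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t**2) * e_t + sigma_t * z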
+
+        # current prediction for x_0
+        if self.model.parameterization != "v":
+            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        else:
+            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+        if quantize_denoised:
+            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+        if dynamic_threshold is not None:
+            raise NotImplementedError()
+
+        # direction pointing to x_t
+        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+        if noise_dropout > 0.:
+            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+        return x_prev, pred_x0
+
+    @torch.no_grad()
+    def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
+               unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
+        num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
+
+        assert t_enc <= num_reference_steps
+        num_steps = t_enc
+
+        if use_original_steps:
+            alphas_next = self.alphas_cumprod[:num_steps]
+            alphas = self.alphas_cumprod_prev[:num_steps]
+        else:
+            alphas_next = self.ddim_alphas[:num_steps]
+            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+        x_next = x0
+        intermediates = []
+        inter_steps = []
+        for i in tqdm(range(num_steps), desc='Encoding Image'):
+            t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
+            if unconditional_guidance_scale == 1.:
+                noise_pred = self.model.apply_model(x_next, t, c)
+            else:
+                assert unconditional_conditioning is not None
+                e_t_uncond, noise_pred = torch.chunk(
+                    self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
+                                           torch.cat((unconditional_conditioning, c))), 2)
+                noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
+
+            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+            weighted_noise_pred = alphas_next[i].sqrt() * (
+                    (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
+            x_next = xt_weighted + weighted_noise_pred
+            if return_intermediates and i % (
+                    num_steps // return_intermediates) == 0 and i < num_steps - 1:
+                intermediates.append(x_next)
+                inter_steps.append(i)
+            elif return_intermediates and i >= num_steps - 2:
+                intermediates.append(x_next)
+                inter_steps.append(i)
+            if callback: callback(i)
+
+        out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
+        if return_intermediates:
+            out.update({'intermediates': intermediates})
+        return x_next, out
+
+    @torch.no_grad()
+    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+        # fast, but does not allow for exact reconstruction
+        # t serves as an index to gather the correct alphas
+        if use_original_steps:
+            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+        else:
+            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+        if noise is None:
+            noise = torch.randn_like(x0)
+        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+                extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
+
+    @torch.no_grad()
+    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+               use_original_steps=False, callback=None):
+
+        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+        timesteps = timesteps[:t_start]
+
+        time_range = np.flip(timesteps)
+        total_steps = timesteps.shape[0]
+        print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+        x_dec = x_latent
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
+                                          unconditional_guidance_scale=unconditional_guidance_scale,
+                                          unconditional_conditioning=unconditional_conditioning)
+            if callback: callback(i)
+        return x_dec
\ No newline at end of file
diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py
new file mode 100644
index 0000000000000000000000000000000000000000..f71a44af48c8cba8e97849b7e6813b3e6f9fe83c
--- /dev/null
+++ b/ldm/models/diffusion/ddpm.py
@@ -0,0 +1,1797 @@
+"""
+wild mixture of
+https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
+https://github.com/CompVis/taming-transformers
+-- merci
+"""
+
+import torch
+import torch.nn as nn
+import numpy as np
+import pytorch_lightning as pl
+from torch.optim.lr_scheduler import LambdaLR
+from einops import rearrange, repeat
+from contextlib import contextmanager, nullcontext
+from functools import partial
+import itertools
+from tqdm import tqdm
+from torchvision.utils import make_grid
+from pytorch_lightning.utilities.distributed import rank_zero_only
+from omegaconf import ListConfig
+
+from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
+from ldm.modules.ema import LitEma
+from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
+from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+from ldm.models.diffusion.ddim import DDIMSampler
+
+
+__conditioning_keys__ = {'concat': 'c_concat',
+                         'crossattn': 'c_crossattn',
+                         'adm': 'y'}
+
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+def uniform_on_device(r1, r2, shape, device):
+    return (r1 - r2) * torch.rand(*shape, device=device) + r2
+
+
+class DDPM(pl.LightningModule):
+    # classic DDPM with Gaussian diffusion, in image space
+    def __init__(self,
+                 unet_config,
+                 timesteps=1000,
+                 beta_schedule="linear",
+                 loss_type="l2",
+                 ckpt_path=None,
+                 ignore_keys=[],
+                 load_only_unet=False,
+                 monitor="val/loss",
+                 use_ema=True,
+                 first_stage_key="image",
+                 image_size=256,
+                 channels=3,
+                 log_every_t=100,
+                 clip_denoised=True,
+                 linear_start=1e-4,
+                 linear_end=2e-2,
+                 cosine_s=8e-3,
+                 given_betas=None,
+                 original_elbo_weight=0.,
+                 v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+                 l_simple_weight=1.,
+                 conditioning_key=None,
+                 parameterization="eps",  # all assuming fixed variance schedules
+                 scheduler_config=None,
+                 use_positional_encodings=False,
+                 learn_logvar=False,
+                 logvar_init=0.,
+                 make_it_fit=False,
+                 ucg_training=None,
+                 reset_ema=False,
+                 reset_num_ema_updates=False,
+                 ):
+        super().__init__()
+        assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
+        self.parameterization = parameterization
+        print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
+        self.cond_stage_model = None
+        self.clip_denoised = clip_denoised
+        self.log_every_t = log_every_t
+        self.first_stage_key = first_stage_key
+        self.image_size = image_size  # try conv?
+        self.channels = channels
+        self.use_positional_encodings = use_positional_encodings
+        self.model = DiffusionWrapper(unet_config, conditioning_key)
+        count_params(self.model, verbose=True)
+        self.use_ema = use_ema
+        if self.use_ema:
+            self.model_ema = LitEma(self.model)
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+        self.use_scheduler = scheduler_config is not None
+        if self.use_scheduler:
+            self.scheduler_config = scheduler_config
+
+        self.v_posterior = v_posterior
+        self.original_elbo_weight = original_elbo_weight
+        self.l_simple_weight = l_simple_weight
+
+        if monitor is not None:
+            self.monitor = monitor
+        self.make_it_fit = make_it_fit
+        if reset_ema: assert exists(ckpt_path)
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+            if reset_ema:
+                assert self.use_ema
+                print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
+                self.model_ema = LitEma(self.model)
+        if reset_num_ema_updates:
+            print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
+            assert self.use_ema
+            self.model_ema.reset_num_updates()
+
+        self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
+                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
+
+        self.loss_type = loss_type
+
+        self.learn_logvar = learn_logvar
+        logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
+        if self.learn_logvar:
+            self.logvar = nn.Parameter(logvar, requires_grad=True)
+        else:
+            self.register_buffer('logvar', logvar)
+
+        self.ucg_training = ucg_training or dict()
+        if self.ucg_training:
+            self.ucg_prng = np.random.RandomState()
+
+    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
+                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+        if exists(given_betas):
+            betas = given_betas
+        else:
+            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+                                       cosine_s=cosine_s)
+        alphas = 1. - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.linear_start = linear_start
+        self.linear_end = linear_end
+        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+
+        self.register_buffer('betas', to_torch(betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+        # calculations for posterior q(x_{t-1} | x_t, x_0)
+        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
+                1. - alphas_cumprod) + self.v_posterior * betas
+        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+        self.register_buffer('posterior_variance', to_torch(posterior_variance))
+        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
+        self.register_buffer('posterior_mean_coef1', to_torch(
+            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
+        self.register_buffer('posterior_mean_coef2', to_torch(
+            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
+
+        if self.parameterization == "eps":
+            lvlb_weights = self.betas ** 2 / (
+                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
+        elif self.parameterization == "x0":
+            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
+        elif self.parameterization == "v":
+            lvlb_weights = torch.ones_like(self.betas ** 2 / (
+                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
+        else:
+            raise NotImplementedError("mu not supported")
+        lvlb_weights[0] = lvlb_weights[1]
+        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
+        assert not torch.isnan(self.lvlb_weights).all()
+
+    @contextmanager
+    def ema_scope(self, context=None):
+        if self.use_ema:
+            self.model_ema.store(self.model.parameters())
+            self.model_ema.copy_to(self.model)
+            if context is not None:
+                print(f"{context}: Switched to EMA weights")
+        try:
+            yield None
+        finally:
+            if self.use_ema:
+                self.model_ema.restore(self.model.parameters())
+                if context is not None:
+                    print(f"{context}: Restored training weights")
+
+    @torch.no_grad()
+    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+        sd = torch.load(path, map_location="cpu")
+        if "state_dict" in list(sd.keys()):
+            sd = sd["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+        if self.make_it_fit:
+            n_params = len([name for name, _ in
+                            itertools.chain(self.named_parameters(),
+                                            self.named_buffers())])
+            for name, param in tqdm(
+                    itertools.chain(self.named_parameters(),
+                                    self.named_buffers()),
+                    desc="Fitting old weights to new weights",
+                    total=n_params
+            ):
+                if name not in sd:
+                    continue
+                old_shape = sd[name].shape
+                new_shape = param.shape
+                assert len(old_shape) == len(new_shape)
+                if len(new_shape) > 2:
+                    # we only modify first two axes
+                    assert new_shape[2:] == old_shape[2:]
+                # assumes first axis corresponds to output dim
+                if not new_shape == old_shape:
+                    new_param = param.clone()
+                    old_param = sd[name]
+                    if len(new_shape) == 1:
+                        for i in range(new_param.shape[0]):
+                            new_param[i] = old_param[i % old_shape[0]]
+                    elif len(new_shape) >= 2:
+                        for i in range(new_param.shape[0]):
+                            for j in range(new_param.shape[1]):
+                                new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]
+
+                        n_used_old = torch.ones(old_shape[1])
+                        for j in range(new_param.shape[1]):
+                            n_used_old[j % old_shape[1]] += 1
+                        n_used_new = torch.zeros(new_shape[1])
+                        for j in range(new_param.shape[1]):
+                            n_used_new[j] = n_used_old[j % old_shape[1]]
+
+                        n_used_new = n_used_new[None, :]
+                        while len(n_used_new.shape) < len(new_shape):
+                            n_used_new = n_used_new.unsqueeze(-1)
+                        new_param /= n_used_new
+
+                    sd[name] = new_param
+
+        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+            sd, strict=False)
+        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+        if len(missing) > 0:
+            print(f"Missing Keys:\n {missing}")
+        if len(unexpected) > 0:
+            print(f"\nUnexpected Keys:\n {unexpected}")
+
+    def q_mean_variance(self, x_start, t):
+        """
+        Get the distribution q(x_t | x_0).
+        :param x_start: the [N x C x ...] tensor of noiseless inputs.
+        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+        """
+        mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
+        variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+        log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+        return mean, variance, log_variance
+
+    def predict_start_from_noise(self, x_t, t, noise):
+        return (
+                extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+                extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+        )
+
+    def predict_start_from_z_and_v(self, x_t, t, v):
+        # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+        # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+        return (
+                extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+        )
+
+    def predict_eps_from_z_and_v(self, x_t, t, v):
+        return (
+                extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
+        )
+
+    def q_posterior(self, x_start, x_t, t):
+        posterior_mean = (
+                extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+                extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+        )
+        posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
+        posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
+        return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+    def p_mean_variance(self, x, t, clip_denoised: bool):
+        model_out = self.model(x, t)
+        if self.parameterization == "eps":
+            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+        elif self.parameterization == "x0":
+            x_recon = model_out
+        if clip_denoised:
+            x_recon.clamp_(-1., 1.)
+
+        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+        return model_mean, posterior_variance, posterior_log_variance
+
+    @torch.no_grad()
+    def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
+        b, *_, device = *x.shape, x.device
+        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
+        noise = noise_like(x.shape, device, repeat_noise)
+        # no noise when t == 0
+        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+    @torch.no_grad()
+    def p_sample_loop(self, shape, return_intermediates=False):
+        device = self.betas.device
+        b = shape[0]
+        img = torch.randn(shape, device=device)
+        intermediates = [img]
+        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
+            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
+                                clip_denoised=self.clip_denoised)
+            if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
+                intermediates.append(img)
+        if return_intermediates:
+            return img, intermediates
+        return img
+
+    @torch.no_grad()
+    def sample(self, batch_size=16, return_intermediates=False):
+        image_size = self.image_size
+        channels = self.channels
+        return self.p_sample_loop((batch_size, channels, image_size, image_size),
+                                  return_intermediates=return_intermediates)
+
+    def q_sample(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+    def get_v(self, x, noise, t):
+        return (
+                extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
+        )
+
+    def get_loss(self, pred, target, mean=True):
+        if self.loss_type == 'l1':
+            loss = (target - pred).abs()
+            if mean:
+                loss = loss.mean()
+        elif self.loss_type == 'l2':
+            if mean:
+                loss = torch.nn.functional.mse_loss(target, pred)
+            else:
+                loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
+        else:
+            raise NotImplementedError("unknown loss type '{loss_type}'")
+
+        return loss
+
+    def p_losses(self, x_start, t, noise=None):
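+        # Noise x_start to the given timestep, let the model predict the target implied
+        # by self.parameterization (noise, x_0, or v), and combine the simple L1/L2
+        # objective with the timestep-weighted VLB term.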
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+        model_out = self.model(x_noisy, t)
+
+        loss_dict = {}
+        if self.parameterization == "eps":
+            target = noise
+        elif self.parameterization == "x0":
+            target = x_start
+        elif self.parameterization == "v":
+            target = self.get_v(x_start, noise, t)
+        else:
+            raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
+
+        loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
+
+        log_prefix = 'train' if self.training else 'val'
+
+        loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
+        loss_simple = loss.mean() * self.l_simple_weight
+
+        loss_vlb = (self.lvlb_weights[t] * loss).mean()
+        loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
+
+        loss = loss_simple + self.original_elbo_weight * loss_vlb
+
+        loss_dict.update({f'{log_prefix}/loss': loss})
+
+        return loss, loss_dict
+
+    def forward(self, x, *args, **kwargs):
+        # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
+        # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
+        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+        return self.p_losses(x, t, *args, **kwargs)
+
+    def get_input(self, batch, k):
+        x = batch[k]
+        if len(x.shape) == 3:
+            x = x[..., None]
+        x = rearrange(x, 'b h w c -> b c h w')
+        x = x.to(memory_format=torch.contiguous_format).float()
+        return x
+
+    def shared_step(self, batch):
+        x = self.get_input(batch, self.first_stage_key)
+        loss, loss_dict = self(x)
+        return loss, loss_dict
+
+    def training_step(self, batch, batch_idx):
+        for k in self.ucg_training:
+            p = self.ucg_training[k]["p"]
+            val = self.ucg_training[k]["val"]
+            if val is None:
+                val = ""
+            for i in range(len(batch[k])):
+                if self.ucg_prng.choice(2, p=[1 - p, p]):
+                    batch[k][i] = val
+
+        loss, loss_dict = self.shared_step(batch)
+
+        self.log_dict(loss_dict, prog_bar=True,
+                      logger=True, on_step=True, on_epoch=True)
+
+        self.log("global_step", self.global_step,
+                 prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+        if self.use_scheduler:
+            lr = self.optimizers().param_groups[0]['lr']
+            self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+        return loss
+
+    @torch.no_grad()
+    def validation_step(self, batch, batch_idx):
+        _, loss_dict_no_ema = self.shared_step(batch)
+        with self.ema_scope():
+            _, loss_dict_ema = self.shared_step(batch)
+            loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
+        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+        self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+
+    def on_train_batch_end(self, *args, **kwargs):
+        if self.use_ema:
+            self.model_ema(self.model)
+
+    def _get_rows_from_list(self, samples):
+        n_imgs_per_row = len(samples)
+        denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
+        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+        return denoise_grid
+
+    @torch.no_grad()
+    def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
+        log = dict()
+        x = self.get_input(batch, self.first_stage_key)
+        N = min(x.shape[0], N)
+        n_row = min(x.shape[0], n_row)
+        x = x.to(self.device)[:N]
+        log["inputs"] = x
+
+        # get diffusion row
+        diffusion_row = list()
+        x_start = x[:n_row]
+
+        for t in range(self.num_timesteps):
+            if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+                t = t.to(self.device).long()
+                noise = torch.randn_like(x_start)
+                x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+                diffusion_row.append(x_noisy)
+
+        log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
+
+        if sample:
+            # get denoise row
+            with self.ema_scope("Plotting"):
+                samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
+
+            log["samples"] = samples
+            log["denoise_row"] = self._get_rows_from_list(denoise_row)
+
+        if return_keys:
+            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+                return log
+            else:
+                return {key: log[key] for key in return_keys}
+        return log
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        params = list(self.model.parameters())
+        if self.learn_logvar:
+            params = params + [self.logvar]
+        opt = torch.optim.AdamW(params, lr=lr)
+        return opt
+
+
+class LatentDiffusion(DDPM):
+    """main class"""
+
+    def __init__(self,
+                 first_stage_config,
+                 cond_stage_config,
+                 num_timesteps_cond=None,
+                 cond_stage_key="image",
+                 cond_stage_trainable=False,
+                 concat_mode=True,
+                 cond_stage_forward=None,
+                 conditioning_key=None,
+                 scale_factor=1.0,
+                 scale_by_std=False,
+                 force_null_conditioning=False,
+                 *args, **kwargs):
+        self.force_null_conditioning = force_null_conditioning
+        self.num_timesteps_cond = default(num_timesteps_cond, 1)
+        self.scale_by_std = scale_by_std
+        assert self.num_timesteps_cond <= kwargs['timesteps']
+        # for backwards compatibility after implementation of DiffusionWrapper
+        if conditioning_key is None:
+            conditioning_key = 'concat' if concat_mode else 'crossattn'
+        if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning:
+            conditioning_key = None
+        ckpt_path = kwargs.pop("ckpt_path", None)
+        reset_ema = kwargs.pop("reset_ema", False)
+        reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
+        ignore_keys = kwargs.pop("ignore_keys", [])
+        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+        self.concat_mode = concat_mode
+        self.cond_stage_trainable = cond_stage_trainable
+        self.cond_stage_key = cond_stage_key
+        try:
+            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
+        except Exception:
+            self.num_downs = 0
+        if not scale_by_std:
+            self.scale_factor = scale_factor
+        else:
+            self.register_buffer('scale_factor', torch.tensor(scale_factor))
+        self.instantiate_first_stage(first_stage_config)
+        self.instantiate_cond_stage(cond_stage_config)
+        self.cond_stage_forward = cond_stage_forward
+        self.clip_denoised = False
+        self.bbox_tokenizer = None
+
+        self.restarted_from_ckpt = False
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys)
+            self.restarted_from_ckpt = True
+            if reset_ema:
+                assert self.use_ema
+                print(
+                    f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
+                self.model_ema = LitEma(self.model)
+        if reset_num_ema_updates:
+            print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
+            assert self.use_ema
+            self.model_ema.reset_num_updates()
+
+    def make_cond_schedule(self, ):
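+        # Map each diffusion step to a conditioning timestep: only the first
+        # num_timesteps_cond entries receive distinct, evenly spaced ids; the
+        # remaining steps fall back to the final timestep.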
+        self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
+        ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
+        self.cond_ids[:self.num_timesteps_cond] = ids
+
+    @rank_zero_only
+    @torch.no_grad()
+    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+        # only for very first batch
+        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
+            assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
+            # set rescale weight to 1./std of encodings
+            print("### USING STD-RESCALING ###")
+            x = super().get_input(batch, self.first_stage_key)
+            x = x.to(self.device)
+            encoder_posterior = self.encode_first_stage(x)
+            z = self.get_first_stage_encoding(encoder_posterior).detach()
+            del self.scale_factor
+            self.register_buffer('scale_factor', 1. / z.flatten().std())
+            print(f"setting self.scale_factor to {self.scale_factor}")
+            print("### USING STD-RESCALING ###")
+
+    def register_schedule(self,
+                          given_betas=None, beta_schedule="linear", timesteps=1000,
+                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+        super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
+
+        self.shorten_cond_schedule = self.num_timesteps_cond > 1
+        if self.shorten_cond_schedule:
+            self.make_cond_schedule()
+
+    def instantiate_first_stage(self, config):
+        model = instantiate_from_config(config)
+        self.first_stage_model = model.eval()
+        self.first_stage_model.train = disabled_train
+        for param in self.first_stage_model.parameters():
+            param.requires_grad = False
+
+    def instantiate_cond_stage(self, config):
+        if not self.cond_stage_trainable:
+            if config == "__is_first_stage__":
+                print("Using first stage also as cond stage.")
+                self.cond_stage_model = self.first_stage_model
+            elif config == "__is_unconditional__":
+                print(f"Training {self.__class__.__name__} as an unconditional model.")
+                self.cond_stage_model = None
+                # self.be_unconditional = True
+            else:
+                model = instantiate_from_config(config)
+                self.cond_stage_model = model.eval()
+                self.cond_stage_model.train = disabled_train
+                for param in self.cond_stage_model.parameters():
+                    param.requires_grad = False
+        else:
+            assert config != '__is_first_stage__'
+            assert config != '__is_unconditional__'
+            model = instantiate_from_config(config)
+            self.cond_stage_model = model
+
+    def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
+        denoise_row = []
+        for zd in tqdm(samples, desc=desc):
+            denoise_row.append(self.decode_first_stage(zd.to(self.device),
+                                                       force_not_quantize=force_no_decoder_quantization))
+        n_imgs_per_row = len(denoise_row)
+        denoise_row = torch.stack(denoise_row)  # n_log_step, n_row, C, H, W
+        denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
+        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+        return denoise_grid
+
+    def get_first_stage_encoding(self, encoder_posterior):
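+        # Draw a latent from the encoder posterior (or accept a plain tensor) and
+        # rescale it by scale_factor so latents are roughly unit-variance.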
+        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+            z = encoder_posterior.sample()
+        elif isinstance(encoder_posterior, torch.Tensor):
+            z = encoder_posterior
+        else:
+            raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
+        return self.scale_factor * z
+
+    def get_learned_conditioning(self, c):
+        if self.cond_stage_forward is None:
+            if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
+                c = self.cond_stage_model.encode(c)
+                if isinstance(c, DiagonalGaussianDistribution):
+                    c = c.mode()
+            else:
+                c = self.cond_stage_model(c)
+        else:
+            assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+        return c
+
+    def meshgrid(self, h, w):
+        y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
+        x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
+
+        arr = torch.cat([y, x], dim=-1)
+        return arr
+
+    def delta_border(self, h, w):
+        """
+        :param h: height
+        :param w: width
+        :return: normalized distance to image border,
+         with min distance = 0 at border and max dist = 0.5 at image center
+        """
+        lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
+        arr = self.meshgrid(h, w) / lower_right_corner
+        dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
+        dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
+        edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
+        return edge_dist
+
+    def get_weighting(self, h, w, Ly, Lx, device):
+        weighting = self.delta_border(h, w)
+        weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
+                               self.split_input_params["clip_max_weight"], )
+        weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
+
+        if self.split_input_params["tie_braker"]:
+            L_weighting = self.delta_border(Ly, Lx)
+            L_weighting = torch.clip(L_weighting,
+                                     self.split_input_params["clip_min_tie_weight"],
+                                     self.split_input_params["clip_max_tie_weight"])
+
+            L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
+            weighting = weighting * L_weighting
+        return weighting
+
+    def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):  # todo load once not every time, shorten code
+        """
+        :param x: img of size (bs, c, h, w)
+        :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
+        """
+        bs, nc, h, w = x.shape
+
+        # number of crops in image
+        Ly = (h - kernel_size[0]) // stride[0] + 1
+        Lx = (w - kernel_size[1]) // stride[1] + 1
+
+        if uf == 1 and df == 1:
+            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+            unfold = torch.nn.Unfold(**fold_params)
+
+            fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
+
+            weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
+            normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap
+            weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
+
+        elif uf > 1 and df == 1:
+            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+            unfold = torch.nn.Unfold(**fold_params)
+
+            fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[1] * uf),
+                                dilation=1, padding=0,
+                                stride=(stride[0] * uf, stride[1] * uf))
+            fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
+
+            weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
+            normalization = fold(weighting).view(1, 1, h * uf, w * uf)  # normalizes the overlap
+            weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
+
+        elif df > 1 and uf == 1:
+            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+            unfold = torch.nn.Unfold(**fold_params)
+
+            fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[1] // df),
+                                dilation=1, padding=0,
+                                stride=(stride[0] // df, stride[1] // df))
+            fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
+
+            weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
+            normalization = fold(weighting).view(1, 1, h // df, w // df)  # normalizes the overlap
+            weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
+
+        else:
+            raise NotImplementedError
+
+        return fold, unfold, normalization, weighting
+
+    @torch.no_grad()
+    def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
+                  cond_key=None, return_original_cond=False, bs=None, return_x=False):
+        x = super().get_input(batch, k)
+        if bs is not None:
+            x = x[:bs]
+        x = x.to(self.device)
+        encoder_posterior = self.encode_first_stage(x)
+        z = self.get_first_stage_encoding(encoder_posterior).detach()
+
+        if self.model.conditioning_key is not None and not self.force_null_conditioning:
+            if cond_key is None:
+                cond_key = self.cond_stage_key
+            if cond_key != self.first_stage_key:
+                if cond_key in ['caption', 'coordinates_bbox', "txt"]:
+                    xc = batch[cond_key]
+                elif cond_key in ['class_label', 'cls']:
+                    xc = batch
+                else:
+                    xc = super().get_input(batch, cond_key).to(self.device)
+            else:
+                xc = x
+            if not self.cond_stage_trainable or force_c_encode:
+                if isinstance(xc, dict) or isinstance(xc, list):
+                    c = self.get_learned_conditioning(xc)
+                else:
+                    c = self.get_learned_conditioning(xc.to(self.device))
+            else:
+                c = xc
+            if bs is not None:
+                c = c[:bs]
+
+            if self.use_positional_encodings:
+                pos_x, pos_y = self.compute_latent_shifts(batch)
+                ckey = __conditioning_keys__[self.model.conditioning_key]
+                c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
+
+        else:
+            c = None
+            xc = None
+            if self.use_positional_encodings:
+                pos_x, pos_y = self.compute_latent_shifts(batch)
+                c = {'pos_x': pos_x, 'pos_y': pos_y}
+        out = [z, c]
+        if return_first_stage_outputs:
+            xrec = self.decode_first_stage(z)
+            out.extend([x, xrec])
+        if return_x:
+            out.extend([x])
+        if return_original_cond:
+            out.append(xc)
+        return out
+
+    @torch.no_grad()
+    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+        if predict_cids:
+            if z.dim() == 4:
+                z = torch.argmax(z.exp(), dim=1).long()
+            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+            z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+        z = 1. / self.scale_factor * z
+        return self.first_stage_model.decode(z)
+
+    @torch.no_grad()
+    def encode_first_stage(self, x):
+        return self.first_stage_model.encode(x)
+
+    def shared_step(self, batch, **kwargs):
+        x, c = self.get_input(batch, self.first_stage_key)
+        loss = self(x, c)
+        return loss
+
+    def forward(self, x, c, *args, **kwargs):
+        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+        if self.model.conditioning_key is not None:
+            assert c is not None
+            if self.cond_stage_trainable:
+                c = self.get_learned_conditioning(c)
+            if self.shorten_cond_schedule:  # TODO: drop this option
+                tc = self.cond_ids[t].to(self.device)
+                c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
+        return self.p_losses(x, c, t, *args, **kwargs)
+
+    def apply_model(self, x_noisy, t, cond, return_ids=False):
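+        # Normalize the conditioning into the dict format expected by DiffusionWrapper:
+        # dicts (hybrid conditioning) pass through unchanged, anything else is wrapped
+        # in a list under 'c_concat' or 'c_crossattn' depending on conditioning_key.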
+        if isinstance(cond, dict):
+            # hybrid case, cond is expected to be a dict
+            pass
+        else:
+            if not isinstance(cond, list):
+                cond = [cond]
+            key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
+            cond = {key: cond}
+
+        x_recon = self.model(x_noisy, t, **cond)
+
+        if isinstance(x_recon, tuple) and not return_ids:
+            return x_recon[0]
+        else:
+            return x_recon
+
+    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
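+        # Invert q_sample given a predicted x_0:
+        # eps = (sqrt(1/alphas_cumprod_t) * x_t - x_0) / sqrt(1/alphas_cumprod_t - 1)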
+        return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
+               extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+    def _prior_bpd(self, x_start):
+        """
+        Get the prior KL term for the variational lower-bound, measured in
+        bits-per-dim.
+        This term can't be optimized, as it only depends on the encoder.
+        :param x_start: the [N x C x ...] tensor of inputs.
+        :return: a batch of [N] KL values (in bits), one per batch element.
+        """
+        batch_size = x_start.shape[0]
+        t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+        kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
+        return mean_flat(kl_prior) / np.log(2.0)
+
+    def p_losses(self, x_start, cond, t, noise=None):
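+        # Same structure as DDPM.p_losses, but operating on latents with conditioning;
+        # the simple loss is additionally reweighted by a per-timestep logvar (learnable
+        # if learn_logvar) before the ELBO term is added.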
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+        model_output = self.apply_model(x_noisy, t, cond)
+
+        loss_dict = {}
+        prefix = 'train' if self.training else 'val'
+
+        if self.parameterization == "x0":
+            target = x_start
+        elif self.parameterization == "eps":
+            target = noise
+        elif self.parameterization == "v":
+            target = self.get_v(x_start, noise, t)
+        else:
+            raise NotImplementedError()
+
+        loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
+        loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
+
+        logvar_t = self.logvar[t].to(self.device)
+        loss = loss_simple / torch.exp(logvar_t) + logvar_t
+        # loss = loss_simple / torch.exp(self.logvar) + self.logvar
+        if self.learn_logvar:
+            loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
+            loss_dict.update({'logvar': self.logvar.data.mean()})
+
+        loss = self.l_simple_weight * loss.mean()
+
+        loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
+        loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
+        loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
+        loss += (self.original_elbo_weight * loss_vlb)
+        loss_dict.update({f'{prefix}/loss': loss})
+
+        return loss, loss_dict
+
+    def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
+                        return_x0=False, score_corrector=None, corrector_kwargs=None):
+        t_in = t
+        model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
+
+        if score_corrector is not None:
+            assert self.parameterization == "eps"
+            model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
+
+        if return_codebook_ids:
+            model_out, logits = model_out
+
+        if self.parameterization == "eps":
+            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+        elif self.parameterization == "x0":
+            x_recon = model_out
+        else:
+            raise NotImplementedError()
+
+        if clip_denoised:
+            x_recon.clamp_(-1., 1.)
+        if quantize_denoised:
+            x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
+        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+        if return_codebook_ids:
+            return model_mean, posterior_variance, posterior_log_variance, logits
+        elif return_x0:
+            return model_mean, posterior_variance, posterior_log_variance, x_recon
+        else:
+            return model_mean, posterior_variance, posterior_log_variance
+
+    @torch.no_grad()
+    def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
+                 return_codebook_ids=False, quantize_denoised=False, return_x0=False,
+                 temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
+        b, *_, device = *x.shape, x.device
+        outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
+                                       return_codebook_ids=return_codebook_ids,
+                                       quantize_denoised=quantize_denoised,
+                                       return_x0=return_x0,
+                                       score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+        if return_codebook_ids:
+            raise DeprecationWarning("Support dropped.")
+            model_mean, _, model_log_variance, logits = outputs
+        elif return_x0:
+            model_mean, _, model_log_variance, x0 = outputs
+        else:
+            model_mean, _, model_log_variance = outputs
+
+        noise = noise_like(x.shape, device, repeat_noise) * temperature
+        if noise_dropout > 0.:
+            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+        # no noise when t == 0
+        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+
+        if return_codebook_ids:
+            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
+        if return_x0:
+            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
+        else:
+            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+    @torch.no_grad()
+    def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
+                              img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
+                              score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
+                              log_every_t=None):
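+        # Like p_sample_loop, but also collects the intermediate x_0 predictions
+        # (via return_x0=True) so progressive reconstructions can be logged.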
+        if not log_every_t:
+            log_every_t = self.log_every_t
+        timesteps = self.num_timesteps
+        if batch_size is not None:
+            b = batch_size
+            shape = [batch_size] + list(shape)
+        else:
+            b = batch_size = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=self.device)
+        else:
+            img = x_T
+        intermediates = []
+        if cond is not None:
+            if isinstance(cond, dict):
+                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+            else:
+                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+
+        if start_T is not None:
+            timesteps = min(timesteps, start_T)
+        iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
+                        total=timesteps) if verbose else reversed(
+            range(0, timesteps))
+        if isinstance(temperature, float):
+            temperature = [temperature] * timesteps
+
+        for i in iterator:
+            ts = torch.full((b,), i, device=self.device, dtype=torch.long)
+            if self.shorten_cond_schedule:
+                assert self.model.conditioning_key != 'hybrid'
+                tc = self.cond_ids[ts].to(cond.device)
+                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+            img, x0_partial = self.p_sample(img, cond, ts,
+                                            clip_denoised=self.clip_denoised,
+                                            quantize_denoised=quantize_denoised, return_x0=True,
+                                            temperature=temperature[i], noise_dropout=noise_dropout,
+                                            score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+            if mask is not None:
+                assert x0 is not None
+                img_orig = self.q_sample(x0, ts)
+                img = img_orig * mask + (1. - mask) * img
+
+            if i % log_every_t == 0 or i == timesteps - 1:
+                intermediates.append(x0_partial)
+            if callback: callback(i)
+            if img_callback: img_callback(img, i)
+        return img, intermediates
+
+    @torch.no_grad()
+    def p_sample_loop(self, cond, shape, return_intermediates=False,
+                      x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
+                      mask=None, x0=None, img_callback=None, start_T=None,
+                      log_every_t=None):
+
+        if not log_every_t:
+            log_every_t = self.log_every_t
+        device = self.betas.device
+        b = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=device)
+        else:
+            img = x_T
+
+        intermediates = [img]
+        if timesteps is None:
+            timesteps = self.num_timesteps
+
+        if start_T is not None:
+            timesteps = min(timesteps, start_T)
+        iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
+            range(0, timesteps))
+
+        if mask is not None:
+            assert x0 is not None
+            assert x0.shape[2:] == mask.shape[2:]  # spatial size has to match
+
+        for i in iterator:
+            ts = torch.full((b,), i, device=device, dtype=torch.long)
+            if self.shorten_cond_schedule:
+                assert self.model.conditioning_key != 'hybrid'
+                tc = self.cond_ids[ts].to(cond.device)
+                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+            img = self.p_sample(img, cond, ts,
+                                clip_denoised=self.clip_denoised,
+                                quantize_denoised=quantize_denoised)
+            if mask is not None:
+                img_orig = self.q_sample(x0, ts)
+                img = img_orig * mask + (1. - mask) * img
+
+            if i % log_every_t == 0 or i == timesteps - 1:
+                intermediates.append(img)
+            if callback: callback(i)
+            if img_callback: img_callback(img, i)
+
+        if return_intermediates:
+            return img, intermediates
+        return img
+
+    @torch.no_grad()
+    def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
+               verbose=True, timesteps=None, quantize_denoised=False,
+               mask=None, x0=None, shape=None, **kwargs):
+        if shape is None:
+            shape = (batch_size, self.channels, self.image_size, self.image_size)
+        if cond is not None:
+            if isinstance(cond, dict):
+                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+            else:
+                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+        return self.p_sample_loop(cond,
+                                  shape,
+                                  return_intermediates=return_intermediates, x_T=x_T,
+                                  verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
+                                  mask=mask, x0=x0)
+
+    @torch.no_grad()
+    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
+        if ddim:
+            ddim_sampler = DDIMSampler(self)
+            shape = (self.channels, self.image_size, self.image_size)
+            samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
+                                                         shape, cond, verbose=False, **kwargs)
+
+        else:
+            samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
+                                                 return_intermediates=True, **kwargs)
+
+        return samples, intermediates
+
+    @torch.no_grad()
+    def get_unconditional_conditioning(self, batch_size, null_label=None):
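+        # Build the "null" conditioning used for classifier-free guidance: encode an
+        # explicit null_label if given, otherwise query the class-conditional cond
+        # stage for its learned unconditional embedding, then repeat it across the batch.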
+        if null_label is not None:
+            xc = null_label
+            if isinstance(xc, ListConfig):
+                xc = list(xc)
+            if isinstance(xc, dict) or isinstance(xc, list):
+                c = self.get_learned_conditioning(xc)
+            else:
+                if hasattr(xc, "to"):
+                    xc = xc.to(self.device)
+                c = self.get_learned_conditioning(xc)
+        else:
+            if self.cond_stage_key in ["class_label", "cls"]:
+                xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
+                return self.get_learned_conditioning(xc)
+            else:
+                raise NotImplementedError("todo")
+        if isinstance(c, list):  # in case the encoder gives us a list
+            for i in range(len(c)):
+                c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device)
+        else:
+            c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
+        return c
+
+    @torch.no_grad()
+    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None,
+                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+                   plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
+                   use_ema_scope=True,
+                   **kwargs):
+        ema_scope = self.ema_scope if use_ema_scope else nullcontext
+        use_ddim = ddim_steps is not None
+
+        log = dict()
+        z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
+                                           return_first_stage_outputs=True,
+                                           force_c_encode=True,
+                                           return_original_cond=True,
+                                           bs=N)
+        N = min(x.shape[0], N)
+        n_row = min(x.shape[0], n_row)
+        log["inputs"] = x
+        log["reconstruction"] = xrec
+        if self.model.conditioning_key is not None:
+            if hasattr(self.cond_stage_model, "decode"):
+                xc = self.cond_stage_model.decode(c)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ["caption", "txt"]:
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ['class_label', "cls"]:
+                try:
+                    xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+                    log['conditioning'] = xc
+                except KeyError:
+                    # probably no "human_label" in batch
+                    pass
+            elif isimage(xc):
+                log["conditioning"] = xc
+            if ismap(xc):
+                log["original_conditioning"] = self.to_rgb(xc)
+
+        if plot_diffusion_rows:
+            # get diffusion row
+            diffusion_row = list()
+            z_start = z[:n_row]
+            for t in range(self.num_timesteps):
+                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+                    t = t.to(self.device).long()
+                    noise = torch.randn_like(z_start)
+                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+                    diffusion_row.append(self.decode_first_stage(z_noisy))
+
+            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
+            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+            log["diffusion_row"] = diffusion_grid
+
+        if sample:
+            # get denoise row
+            with ema_scope("Sampling"):
+                samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                         ddim_steps=ddim_steps, eta=ddim_eta)
+                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+            x_samples = self.decode_first_stage(samples)
+            log["samples"] = x_samples
+            if plot_denoise_rows:
+                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                log["denoise_row"] = denoise_grid
+
+            if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
+                    self.first_stage_model, IdentityFirstStage):
+                # also display when quantizing x0 while sampling
+                with ema_scope("Plotting Quantized Denoised"):
+                    samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                             ddim_steps=ddim_steps, eta=ddim_eta,
+                                                             quantize_denoised=True)
+                    # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
+                    #                                      quantize_denoised=True)
+                x_samples = self.decode_first_stage(samples.to(self.device))
+                log["samples_x0_quantized"] = x_samples
+
+        if unconditional_guidance_scale > 1.0:
+            uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+            if self.model.conditioning_key == "crossattn-adm":
+                uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
+            with ema_scope("Sampling with classifier-free guidance"):
+                samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                 ddim_steps=ddim_steps, eta=ddim_eta,
+                                                 unconditional_guidance_scale=unconditional_guidance_scale,
+                                                 unconditional_conditioning=uc,
+                                                 )
+                x_samples_cfg = self.decode_first_stage(samples_cfg)
+                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+        if inpaint:
+            # make a simple center square
+            b, h, w = z.shape[0], z.shape[2], z.shape[3]
+            mask = torch.ones(N, h, w).to(self.device)
+            # zeros will be filled in
+            mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
+            mask = mask[:, None, ...]
+            with ema_scope("Plotting Inpaint"):
+                samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+                                             ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+            x_samples = self.decode_first_stage(samples.to(self.device))
+            log["samples_inpainting"] = x_samples
+            log["mask"] = mask
+
+            # outpaint
+            mask = 1. - mask
+            with ema_scope("Plotting Outpaint"):
+                samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+                                             ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+            x_samples = self.decode_first_stage(samples.to(self.device))
+            log["samples_outpainting"] = x_samples
+
+        if plot_progressive_rows:
+            with ema_scope("Plotting Progressives"):
+                img, progressives = self.progressive_denoising(c,
+                                                               shape=(self.channels, self.image_size, self.image_size),
+                                                               batch_size=N)
+            prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+            log["progressive_row"] = prog_row
+
+        if return_keys:
+            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+                return log
+            else:
+                return {key: log[key] for key in return_keys}
+        return log
+
+    def configure_optimizers(self):
+        lr = self.learning_rate
+        params = list(self.model.parameters())
+        if self.cond_stage_trainable:
+            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
+            params = params + list(self.cond_stage_model.parameters())
+        if self.learn_logvar:
+            print('Diffusion model optimizing logvar')
+            params.append(self.logvar)
+        opt = torch.optim.AdamW(params, lr=lr)
+        if self.use_scheduler:
+            assert 'target' in self.scheduler_config
+            scheduler = instantiate_from_config(self.scheduler_config)
+
+            print("Setting up LambdaLR scheduler...")
+            scheduler = [
+                {
+                    'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
+                    'interval': 'step',
+                    'frequency': 1
+                }]
+            return [opt], scheduler
+        return opt
+
+    @torch.no_grad()
+    def to_rgb(self, x):
+        x = x.float()
+        if not hasattr(self, "colorize"):
+            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+        x = nn.functional.conv2d(x, weight=self.colorize)
+        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
+        return x
+
+
+class DiffusionWrapper(pl.LightningModule):
+    def __init__(self, diff_model_config, conditioning_key):
+        super().__init__()
+        self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
+        self.diffusion_model = instantiate_from_config(diff_model_config)
+        self.conditioning_key = conditioning_key
+        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
+
+    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
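+        # Route the conditioning into the UNet according to conditioning_key:
+        # 'concat' appends channels to the input, 'crossattn' feeds context to the
+        # cross-attention layers, 'adm' passes vector conditioning as y, and the
+        # hybrid variants combine these.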
+        if self.conditioning_key is None:
+            out = self.diffusion_model(x, t)
+        elif self.conditioning_key == 'concat':
+            xc = torch.cat([x] + c_concat, dim=1)
+            out = self.diffusion_model(xc, t)
+        elif self.conditioning_key == 'crossattn':
+            if not self.sequential_cross_attn:
+                cc = torch.cat(c_crossattn, 1)
+            else:
+                cc = c_crossattn
+            out = self.diffusion_model(x, t, context=cc)
+        elif self.conditioning_key == 'hybrid':
+            xc = torch.cat([x] + c_concat, dim=1)
+            cc = torch.cat(c_crossattn, 1)
+            out = self.diffusion_model(xc, t, context=cc)
+        elif self.conditioning_key == 'hybrid-adm':
+            assert c_adm is not None
+            xc = torch.cat([x] + c_concat, dim=1)
+            cc = torch.cat(c_crossattn, 1)
+            out = self.diffusion_model(xc, t, context=cc, y=c_adm)
+        elif self.conditioning_key == 'crossattn-adm':
+            assert c_adm is not None
+            cc = torch.cat(c_crossattn, 1)
+            out = self.diffusion_model(x, t, context=cc, y=c_adm)
+        elif self.conditioning_key == 'adm':
+            cc = c_crossattn[0]
+            out = self.diffusion_model(x, t, y=cc)
+        else:
+            raise NotImplementedError()
+
+        return out
+
+
+class LatentUpscaleDiffusion(LatentDiffusion):
+    def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        # assumes that neither the cond_stage nor the low_scale_model contain trainable params
+        assert not self.cond_stage_trainable
+        self.instantiate_low_stage(low_scale_config)
+        self.low_scale_key = low_scale_key
+        self.noise_level_key = noise_level_key
+
+    def instantiate_low_stage(self, config):
+        model = instantiate_from_config(config)
+        self.low_scale_model = model.eval()
+        self.low_scale_model.train = disabled_train
+        for param in self.low_scale_model.parameters():
+            param.requires_grad = False
+
+    @torch.no_grad()
+    def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
+        if not log_mode:
+            z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
+        else:
+            z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+                                                  force_c_encode=True, return_original_cond=True, bs=bs)
+        x_low = batch[self.low_scale_key][:bs]
+        x_low = rearrange(x_low, 'b h w c -> b c h w')
+        x_low = x_low.to(memory_format=torch.contiguous_format).float()
+        zx, noise_level = self.low_scale_model(x_low)
+        if self.noise_level_key is not None:
+            # get noise level from batch instead, e.g. when extracting a custom noise level for bsr
+            raise NotImplementedError('TODO')
+
+        all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
+        if log_mode:
+            # TODO: maybe disable if too expensive
+            x_low_rec = self.low_scale_model.decode(zx)
+            return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
+        return z, all_conds
+
+    @torch.no_grad()
+    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+                   plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
+                   unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
+                   **kwargs):
+        ema_scope = self.ema_scope if use_ema_scope else nullcontext
+        use_ddim = ddim_steps is not None
+
+        log = dict()
+        z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
+                                                                          log_mode=True)
+        N = min(x.shape[0], N)
+        n_row = min(x.shape[0], n_row)
+        log["inputs"] = x
+        log["reconstruction"] = xrec
+        log["x_lr"] = x_low
+        log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
+        if self.model.conditioning_key is not None:
+            if hasattr(self.cond_stage_model, "decode"):
+                xc = self.cond_stage_model.decode(c)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ["caption", "txt"]:
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ['class_label', 'cls']:
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+                log['conditioning'] = xc
+            elif isimage(xc):
+                log["conditioning"] = xc
+            if ismap(xc):
+                log["original_conditioning"] = self.to_rgb(xc)
+
+        if plot_diffusion_rows:
+            # get diffusion row
+            diffusion_row = list()
+            z_start = z[:n_row]
+            for t in range(self.num_timesteps):
+                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+                    t = t.to(self.device).long()
+                    noise = torch.randn_like(z_start)
+                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+                    diffusion_row.append(self.decode_first_stage(z_noisy))
+
+            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
+            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+            log["diffusion_row"] = diffusion_grid
+
+        if sample:
+            # get denoise row
+            with ema_scope("Sampling"):
+                samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                         ddim_steps=ddim_steps, eta=ddim_eta)
+                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+            x_samples = self.decode_first_stage(samples)
+            log["samples"] = x_samples
+            if plot_denoise_rows:
+                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                log["denoise_row"] = denoise_grid
+
+        if unconditional_guidance_scale > 1.0:
+            uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+            # TODO explore better "unconditional" choices for the other keys
+            # maybe guide away from empty text label and highest noise level and maximally degraded zx?
+            uc = dict()
+            for k in c:
+                if k == "c_crossattn":
+                    assert isinstance(c[k], list) and len(c[k]) == 1
+                    uc[k] = [uc_tmp]
+                elif k == "c_adm":  # todo: only run with text-based guidance?
+                    assert isinstance(c[k], torch.Tensor)
+                    #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
+                    uc[k] = c[k]
+                elif isinstance(c[k], list):
+                    uc[k] = [c[k][i] for i in range(len(c[k]))]
+                else:
+                    uc[k] = c[k]
+
+            with ema_scope("Sampling with classifier-free guidance"):
+                samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+                                                 ddim_steps=ddim_steps, eta=ddim_eta,
+                                                 unconditional_guidance_scale=unconditional_guidance_scale,
+                                                 unconditional_conditioning=uc,
+                                                 )
+                x_samples_cfg = self.decode_first_stage(samples_cfg)
+                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+        if plot_progressive_rows:
+            with ema_scope("Plotting Progressives"):
+                img, progressives = self.progressive_denoising(c,
+                                                               shape=(self.channels, self.image_size, self.image_size),
+                                                               batch_size=N)
+            prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+            log["progressive_row"] = prog_row
+
+        return log
+
+
+class LatentFinetuneDiffusion(LatentDiffusion):
+    """
+         Basis for different finetunes, such as inpainting or depth2image
+         To disable finetuning mode, set finetune_keys to None
+    """
+
+    def __init__(self,
+                 concat_keys: tuple,
+                 finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
+                                "model_ema.diffusion_modelinput_blocks00weight"
+                                ),
+                 keep_finetune_dims=4,
+                 # if model was trained without concat mode before and we would like to keep these channels
+                 c_concat_log_start=None,  # to log reconstruction of c_concat codes
+                 c_concat_log_end=None,
+                 *args, **kwargs
+                 ):
+        ckpt_path = kwargs.pop("ckpt_path", None)
+        ignore_keys = kwargs.pop("ignore_keys", list())
+        super().__init__(*args, **kwargs)
+        self.finetune_keys = finetune_keys
+        self.concat_keys = concat_keys
+        self.keep_dims = keep_finetune_dims
+        self.c_concat_log_start = c_concat_log_start
+        self.c_concat_log_end = c_concat_log_end
+        if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint'
+        if exists(ckpt_path):
+            self.init_from_ckpt(ckpt_path, ignore_keys)
+
+    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
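+        # Load a checkpoint, dropping ignore_keys; for weights listed in finetune_keys
+        # the old tensor is copied into the first keep_dims input channels of a
+        # zero-initialized replacement so newly added concat channels start at zero.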
+        sd = torch.load(path, map_location="cpu")
+        if "state_dict" in list(sd.keys()):
+            sd = sd["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print("Deleting key {} from state_dict.".format(k))
+                    del sd[k]
+
+            # make it explicit, finetune by including extra input channels
+            if exists(self.finetune_keys) and k in self.finetune_keys:
+                new_entry = None
+                for name, param in self.named_parameters():
+                    if name in self.finetune_keys:
+                        print(
+                            f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
+                        new_entry = torch.zeros_like(param)  # zero init
+                assert exists(new_entry), 'did not find matching parameter to modify'
+                new_entry[:, :self.keep_dims, ...] = sd[k]
+                sd[k] = new_entry
+
+        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+            sd, strict=False)
+        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+        if len(missing) > 0:
+            print(f"Missing Keys: {missing}")
+        if len(unexpected) > 0:
+            print(f"Unexpected Keys: {unexpected}")
+
+    @torch.no_grad()
+    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+                   plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
+                   use_ema_scope=True,
+                   **kwargs):
+        ema_scope = self.ema_scope if use_ema_scope else nullcontext
+        use_ddim = ddim_steps is not None
+
+        log = dict()
+        z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
+        c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
+        N = min(x.shape[0], N)
+        n_row = min(x.shape[0], n_row)
+        log["inputs"] = x
+        log["reconstruction"] = xrec
+        if self.model.conditioning_key is not None:
+            if hasattr(self.cond_stage_model, "decode"):
+                xc = self.cond_stage_model.decode(c)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ["caption", "txt"]:
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ['class_label', 'cls']:
+                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+                log['conditioning'] = xc
+            elif isimage(xc):
+                log["conditioning"] = xc
+            if ismap(xc):
+                log["original_conditioning"] = self.to_rgb(xc)
+
+        if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
+            log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end])
+
+        if plot_diffusion_rows:
+            # get diffusion row
+            diffusion_row = list()
+            z_start = z[:n_row]
+            for t in range(self.num_timesteps):
+                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+                    t = t.to(self.device).long()
+                    noise = torch.randn_like(z_start)
+                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+                    diffusion_row.append(self.decode_first_stage(z_noisy))
+
+            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
+            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+            log["diffusion_row"] = diffusion_grid
+
+        if sample:
+            # get denoise row
+            with ema_scope("Sampling"):
+                samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+                                                         batch_size=N, ddim=use_ddim,
+                                                         ddim_steps=ddim_steps, eta=ddim_eta)
+                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+            x_samples = self.decode_first_stage(samples)
+            log["samples"] = x_samples
+            if plot_denoise_rows:
+                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                log["denoise_row"] = denoise_grid
+
+        if unconditional_guidance_scale > 1.0:
+            uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+            uc_cat = c_cat
+            uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
+            with ema_scope("Sampling with classifier-free guidance"):
+                samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+                                                 batch_size=N, ddim=use_ddim,
+                                                 ddim_steps=ddim_steps, eta=ddim_eta,
+                                                 unconditional_guidance_scale=unconditional_guidance_scale,
+                                                 unconditional_conditioning=uc_full,
+                                                 )
+                x_samples_cfg = self.decode_first_stage(samples_cfg)
+                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+        return log
+
+
+class LatentInpaintDiffusion(LatentFinetuneDiffusion):
+    """
+    Can run either as a pure inpainting model (concat mode only) or with mixed conditionings,
+    e.g. mask as concat and text via cross-attn.
+    To disable finetuning mode, set finetune_keys to None.
+    """
+
+    def __init__(self,
+                 concat_keys=("mask", "masked_image"),
+                 masked_image_key="masked_image",
+                 *args, **kwargs
+                 ):
+        super().__init__(concat_keys, *args, **kwargs)
+        self.masked_image_key = masked_image_key
+        assert self.masked_image_key in concat_keys
+
+    @torch.no_grad()
+    def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+        # note: restricted to non-trainable encoders currently
+        assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
+        z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+                                              force_c_encode=True, return_original_cond=True, bs=bs)
+
+        assert exists(self.concat_keys)
+        c_cat = list()
+        for ck in self.concat_keys:
+            cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
+            if bs is not None:
+                cc = cc[:bs]
+                cc = cc.to(self.device)
+            bchw = z.shape
+            if ck != self.masked_image_key:
+                cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+            else:
+                cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
+            c_cat.append(cc)
+        c_cat = torch.cat(c_cat, dim=1)
+        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+        if return_first_stage_outputs:
+            return z, all_conds, x, xrec, xc
+        return z, all_conds
+
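+    # Shape sketch for the concat conditioning built in get_input above (assuming 512x512 inputs and
+    # the usual f=8 autoencoder): the mask (B, 1, 512, 512) is resized to the latent grid
+    # (B, 1, 64, 64), the masked image is encoded to (B, 4, 64, 64), and their concatenation is the
+    # 5-channel c_concat that the UNet stacks with the 4 noisy latent channels (9 inputs in total).
+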
+    @torch.no_grad()
+    def log_images(self, *args, **kwargs):
+        log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs)
+        log["masked_image"] = rearrange(args[0]["masked_image"],
+                                        'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
+        return log
+
+
+class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
+    """
+    condition on monocular depth estimation
+    """
+
+    def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
+        super().__init__(concat_keys=concat_keys, *args, **kwargs)
+        self.depth_model = instantiate_from_config(depth_stage_config)
+        self.depth_stage_key = concat_keys[0]
+
+    @torch.no_grad()
+    def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+        # note: restricted to non-trainable encoders currently
+        assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img'
+        z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+                                              force_c_encode=True, return_original_cond=True, bs=bs)
+
+        assert exists(self.concat_keys)
+        assert len(self.concat_keys) == 1
+        c_cat = list()
+        for ck in self.concat_keys:
+            cc = batch[ck]
+            if bs is not None:
+                cc = cc[:bs]
+                cc = cc.to(self.device)
+            cc = self.depth_model(cc)
+            cc = torch.nn.functional.interpolate(
+                cc,
+                size=z.shape[2:],
+                mode="bicubic",
+                align_corners=False,
+            )
+
+            depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
+                                                                                           keepdim=True)
+            cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.
+            c_cat.append(cc)
+        c_cat = torch.cat(c_cat, dim=1)
+        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+        if return_first_stage_outputs:
+            return z, all_conds, x, xrec, xc
+        return z, all_conds
+
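+    # The depth map predicted by the frozen depth model is normalised per sample to [-1, 1] above:
+    # e.g. depths spanning [2.0, 10.0] map via 2 * (d - 2.0) / (8.0 + 0.001) - 1, so the conditioning
+    # is invariant to the absolute depth scale.
+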
+    @torch.no_grad()
+    def log_images(self, *args, **kwargs):
+        log = super().log_images(*args, **kwargs)
+        depth = self.depth_model(args[0][self.depth_stage_key])
+        depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \
+                               torch.amax(depth, dim=[1, 2, 3], keepdim=True)
+        log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1.
+        return log
+
+
+class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
+    """
+        condition on low-res image (and optionally on some spatial noise augmentation)
+    """
+    def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None,
+                 low_scale_config=None, low_scale_key=None, *args, **kwargs):
+        super().__init__(concat_keys=concat_keys, *args, **kwargs)
+        self.reshuffle_patch_size = reshuffle_patch_size
+        self.low_scale_model = None
+        if low_scale_config is not None:
+            print("Initializing a low-scale model")
+            assert exists(low_scale_key)
+            self.instantiate_low_stage(low_scale_config)
+            self.low_scale_key = low_scale_key
+
+    def instantiate_low_stage(self, config):
+        model = instantiate_from_config(config)
+        self.low_scale_model = model.eval()
+        self.low_scale_model.train = disabled_train
+        for param in self.low_scale_model.parameters():
+            param.requires_grad = False
+
+    @torch.no_grad()
+    def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+        # note: restricted to non-trainable encoders currently
+        assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft'
+        z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+                                              force_c_encode=True, return_original_cond=True, bs=bs)
+
+        assert exists(self.concat_keys)
+        assert len(self.concat_keys) == 1
+        # optionally make spatial noise_level here
+        c_cat = list()
+        noise_level = None
+        for ck in self.concat_keys:
+            cc = batch[ck]
+            cc = rearrange(cc, 'b h w c -> b c h w')
+            if exists(self.reshuffle_patch_size):
+                assert isinstance(self.reshuffle_patch_size, int)
+                cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
+                               p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size)
+            if bs is not None:
+                cc = cc[:bs]
+                cc = cc.to(self.device)
+            if exists(self.low_scale_model) and ck == self.low_scale_key:
+                cc, noise_level = self.low_scale_model(cc)
+            c_cat.append(cc)
+        c_cat = torch.cat(c_cat, dim=1)
+        if exists(noise_level):
+            all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
+        else:
+            all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+        if return_first_stage_outputs:
+            return z, all_conds, x, xrec, xc
+        return z, all_conds
+
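+    # Shape sketch for the optional patch reshuffle in get_input above (a pixel-unshuffle; assuming
+    # reshuffle_patch_size=2): a low-res RGB image of shape (B, 3, 256, 256) becomes
+    # (B, 12, 128, 128) under rearrange('b c (p1 h) (p2 w) -> b (p1 p2 c) h w', p1=2, p2=2).
+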
+    @torch.no_grad()
+    def log_images(self, *args, **kwargs):
+        log = super().log_images(*args, **kwargs)
+        log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
+        return log
diff --git a/ldm/models/diffusion/dpm_solver/__init__.py b/ldm/models/diffusion/dpm_solver/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7427f38c07530afbab79154ea8aaf88c4bf70a08
--- /dev/null
+++ b/ldm/models/diffusion/dpm_solver/__init__.py
@@ -0,0 +1 @@
+from .sampler import DPMSolverSampler
\ No newline at end of file
diff --git a/ldm/models/diffusion/dpm_solver/dpm_solver.py b/ldm/models/diffusion/dpm_solver/dpm_solver.py
new file mode 100644
index 0000000000000000000000000000000000000000..095e5ba3ce0b1aa7f4b3f1e2e5d8fff7cfe6dc8c
--- /dev/null
+++ b/ldm/models/diffusion/dpm_solver/dpm_solver.py
@@ -0,0 +1,1154 @@
+import torch
+import torch.nn.functional as F
+import math
+from tqdm import tqdm
+
+
+class NoiseScheduleVP:
+    def __init__(
+            self,
+            schedule='discrete',
+            betas=None,
+            alphas_cumprod=None,
+            continuous_beta_0=0.1,
+            continuous_beta_1=20.,
+    ):
+        """Create a wrapper class for the forward SDE (VP type).
+        ***
+        Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
+                We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
+        ***
+        The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
+        We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
+        Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
+            log_alpha_t = self.marginal_log_mean_coeff(t)
+            sigma_t = self.marginal_std(t)
+            lambda_t = self.marginal_lambda(t)
+        Moreover, as lambda(t) is an invertible function, we also support its inverse function:
+            t = self.inverse_lambda(lambda_t)
+        ===============================================================
+        We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
+        1. For discrete-time DPMs:
+            For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
+                t_i = (i + 1) / N
+            e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
+            We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
+            Args:
+                betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
+                alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
+            Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
+            **Important**: Please pay special attention to the argument `alphas_cumprod`:
+                The `alphas_cumprod` is the \hat{alpha_n} array in the notation of DDPM. Specifically, DDPMs assume that
+                    q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
+                Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
+                    alpha_{t_n} = \sqrt{\hat{alpha_n}},
+                and
+                    log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
+        2. For continuous-time DPMs:
+            We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
+            schedule are the default settings in DDPM and improved-DDPM:
+            Args:
+                beta_min: A `float` number. The smallest beta for the linear schedule.
+                beta_max: A `float` number. The largest beta for the linear schedule.
+                cosine_s: A `float` number. The hyperparameter in the cosine schedule.
+                cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
+                T: A `float` number. The ending time of the forward process.
+        ===============================================================
+        Args:
+            schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
+                    'linear' or 'cosine' for continuous-time DPMs.
+        Returns:
+            A wrapper object of the forward SDE (VP type).
+
+        ===============================================================
+        Example:
+        # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
+        >>> ns = NoiseScheduleVP('discrete', betas=betas)
+        # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
+        >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+        # For continuous-time DPMs (VPSDE), linear schedule:
+        >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
+        """
+
+        if schedule not in ['discrete', 'linear', 'cosine']:
+            raise ValueError(
+                "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
+                    schedule))
+
+        self.schedule = schedule
+        if schedule == 'discrete':
+            if betas is not None:
+                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
+            else:
+                assert alphas_cumprod is not None
+                log_alphas = 0.5 * torch.log(alphas_cumprod)
+            self.total_N = len(log_alphas)
+            self.T = 1.
+            self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
+            self.log_alpha_array = log_alphas.reshape((1, -1,))
+        else:
+            self.total_N = 1000
+            self.beta_0 = continuous_beta_0
+            self.beta_1 = continuous_beta_1
+            self.cosine_s = 0.008
+            self.cosine_beta_max = 999.
+            self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
+                        1. + self.cosine_s) / math.pi - self.cosine_s
+            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
+            self.schedule = schedule
+            if schedule == 'cosine':
+                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
+                # Note that T = 0.9946 may not be the optimal setting. However, we find it works well.
+                self.T = 0.9946
+            else:
+                self.T = 1.
+
+    def marginal_log_mean_coeff(self, t):
+        """
+        Compute log(alpha_t) of a given continuous-time label t in [0, T].
+        """
+        if self.schedule == 'discrete':
+            return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
+                                  self.log_alpha_array.to(t.device)).reshape((-1))
+        elif self.schedule == 'linear':
+            return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
+        elif self.schedule == 'cosine':
+            log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
+            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
+            return log_alpha_t
+
+    def marginal_alpha(self, t):
+        """
+        Compute alpha_t of a given continuous-time label t in [0, T].
+        """
+        return torch.exp(self.marginal_log_mean_coeff(t))
+
+    def marginal_std(self, t):
+        """
+        Compute sigma_t of a given continuous-time label t in [0, T].
+        """
+        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
+
+    def marginal_lambda(self, t):
+        """
+        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
+        """
+        log_mean_coeff = self.marginal_log_mean_coeff(t)
+        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
+        return log_mean_coeff - log_std
+
+    def inverse_lambda(self, lamb):
+        """
+        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
+        """
+        if self.schedule == 'linear':
+            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+            Delta = self.beta_0 ** 2 + tmp
+            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
+        elif self.schedule == 'discrete':
+            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
+            t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
+                               torch.flip(self.t_array.to(lamb.device), [1]))
+            return t.reshape((-1,))
+        else:
+            log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+            t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
+                        1. + self.cosine_s) / math.pi - self.cosine_s
+            t = t_fn(log_alpha)
+            return t
+
+
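+# A minimal sanity-check sketch for NoiseScheduleVP (illustrative only; it is not called anywhere in
+# this file). It assumes a standard DDPM linear beta schedule and checks the VP identities
+# alpha_t^2 + sigma_t^2 = 1 and lambda_t = log(alpha_t / sigma_t) at an arbitrary time.
+def _noise_schedule_vp_sanity_check():
+    betas = torch.linspace(1e-4, 2e-2, 1000)  # assumed DDPM linear schedule with N = 1000
+    ns = NoiseScheduleVP('discrete', betas=betas)
+    t = torch.tensor([0.5])
+    alpha_t, sigma_t = ns.marginal_alpha(t), ns.marginal_std(t)
+    assert torch.allclose(alpha_t ** 2 + sigma_t ** 2, torch.ones_like(alpha_t), atol=1e-5)
+    assert torch.allclose(ns.marginal_lambda(t), torch.log(alpha_t / sigma_t), atol=1e-5)
+
+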
+def model_wrapper(
+        model,
+        noise_schedule,
+        model_type="noise",
+        model_kwargs={},
+        guidance_type="uncond",
+        condition=None,
+        unconditional_condition=None,
+        guidance_scale=1.,
+        classifier_fn=None,
+        classifier_kwargs={},
+):
+    """Create a wrapper function for the noise prediction model.
+    DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
+    first wrap the model function into a noise prediction model that accepts the continuous time as input.
+    We support four types of the diffusion model by setting `model_type`:
+        1. "noise": noise prediction model. (Trained by predicting noise).
+        2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
+        3. "v": velocity prediction model. (Trained by predicting the velocity).
+            The "v" prediction is derived in Appendix D of [1] and is used in Imagen Video [2].
+            [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
+                arXiv preprint arXiv:2202.00512 (2022).
+            [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
+                arXiv preprint arXiv:2210.02303 (2022).
+
+        4. "score": marginal score function. (Trained by denoising score matching).
+            Note that the score function and the noise prediction model follow a simple relationship:
+            ```
+                noise(x_t, t) = -sigma_t * score(x_t, t)
+            ```
+    We support three types of guided sampling by DPMs by setting `guidance_type`:
+        1. "uncond": unconditional sampling by DPMs.
+            The input `model` has the following format:
+            ``
+                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+            ``
+        2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
+            The input `model` has the following format:
+            ``
+                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+            ``
+            The input `classifier_fn` has the following format:
+            ``
+                classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
+            ``
+            [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
+                in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
+        3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
+            The input `model` has the following format:
+            ``
+                model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
+            ``
+            And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
+            [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
+                arXiv preprint arXiv:2207.12598 (2022).
+
+    The `t_input` is the time label of the model, which may be a discrete-time label (i.e. 0 to 999)
+    or a continuous-time label (i.e. epsilon to T).
+    We wrap the model function so that it accepts only `x` and `t_continuous` as inputs and outputs the predicted noise:
+    ``
+        def model_fn(x, t_continuous) -> noise:
+            t_input = get_model_input_time(t_continuous)
+            return noise_pred(model, x, t_input, **model_kwargs)
+    ``
+    where `t_continuous` is the continuous time label (i.e. epsilon to T), and we use `model_fn` for DPM-Solver.
+    ===============================================================
+    Args:
+        model: A diffusion model with the corresponding format described above.
+        noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+        model_type: A `str`. The parameterization type of the diffusion model.
+                    "noise" or "x_start" or "v" or "score".
+        model_kwargs: A `dict`. A dict for the other inputs of the model function.
+        guidance_type: A `str`. The type of the guidance for sampling.
+                    "uncond" or "classifier" or "classifier-free".
+        condition: A pytorch tensor. The condition for the guided sampling.
+                    Only used for "classifier" or "classifier-free" guidance type.
+        unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
+                    Only used for "classifier-free" guidance type.
+        guidance_scale: A `float`. The scale for the guided sampling.
+        classifier_fn: A classifier function. Only used for the classifier guidance.
+        classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
+    Returns:
+        A noise prediction model that accepts the noised data and the continuous time as the inputs.
+    """
+
+    def get_model_input_time(t_continuous):
+        """
+        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
+        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
+        For continuous-time DPMs, we just use `t_continuous`.
+        """
+        if noise_schedule.schedule == 'discrete':
+            return (t_continuous - 1. / noise_schedule.total_N) * 1000.
+        else:
+            return t_continuous
+
+    def noise_pred_fn(x, t_continuous, cond=None):
+        if t_continuous.reshape((-1,)).shape[0] == 1:
+            t_continuous = t_continuous.expand((x.shape[0]))
+        t_input = get_model_input_time(t_continuous)
+        if cond is None:
+            output = model(x, t_input, **model_kwargs)
+        else:
+            output = model(x, t_input, cond, **model_kwargs)
+        if model_type == "noise":
+            return output
+        elif model_type == "x_start":
+            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+            dims = x.dim()
+            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
+        elif model_type == "v":
+            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+            dims = x.dim()
+            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
+        elif model_type == "score":
+            sigma_t = noise_schedule.marginal_std(t_continuous)
+            dims = x.dim()
+            return -expand_dims(sigma_t, dims) * output
+
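+    # Parameterization identities used in noise_pred_fn above (with x = alpha_t * x0 + sigma_t * eps
+    # and alpha_t^2 + sigma_t^2 = 1):
+    #     "x_start": eps = (x - alpha_t * x0) / sigma_t
+    #     "v":       eps = alpha_t * v + sigma_t * x,  since v = alpha_t * eps - sigma_t * x0
+    #     "score":   eps = -sigma_t * score,           since score = -eps / sigma_t
+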
+    def cond_grad_fn(x, t_input):
+        """
+        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
+        """
+        with torch.enable_grad():
+            x_in = x.detach().requires_grad_(True)
+            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
+            return torch.autograd.grad(log_prob.sum(), x_in)[0]
+
+    def model_fn(x, t_continuous):
+        """
+        The noise prediction model function that is used for DPM-Solver.
+        """
+        if t_continuous.reshape((-1,)).shape[0] == 1:
+            t_continuous = t_continuous.expand((x.shape[0]))
+        if guidance_type == "uncond":
+            return noise_pred_fn(x, t_continuous)
+        elif guidance_type == "classifier":
+            assert classifier_fn is not None
+            t_input = get_model_input_time(t_continuous)
+            cond_grad = cond_grad_fn(x, t_input)
+            sigma_t = noise_schedule.marginal_std(t_continuous)
+            noise = noise_pred_fn(x, t_continuous)
+            return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
+        elif guidance_type == "classifier-free":
+            if guidance_scale == 1. or unconditional_condition is None:
+                return noise_pred_fn(x, t_continuous, cond=condition)
+            else:
+                x_in = torch.cat([x] * 2)
+                t_in = torch.cat([t_continuous] * 2)
+                c_in = torch.cat([unconditional_condition, condition])
+                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
+                return noise_uncond + guidance_scale * (noise - noise_uncond)
+
+    assert model_type in ["noise", "x_start", "v", "score"]
+    assert guidance_type in ["uncond", "classifier", "classifier-free"]
+    return model_fn
+
+
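+# A hedged usage sketch, not part of the reference implementation: it shows how model_wrapper and
+# DPM_Solver are typically combined for classifier-free-guided latent-diffusion sampling. The
+# `ldm_model` argument (assumed to expose `alphas_cumprod` and `apply_model(x, t, c)`) and the
+# default hyperparameters are illustrative assumptions; `sample` is defined further down in this file.
+def _dpm_solver_cfg_sampling_sketch(ldm_model, cond, uncond, shape, steps=20, cfg_scale=7.5):
+    device = next(ldm_model.parameters()).device
+    ns = NoiseScheduleVP('discrete', alphas_cumprod=ldm_model.alphas_cumprod)
+    model_fn = model_wrapper(
+        lambda x, t, c: ldm_model.apply_model(x, t, c),
+        ns,
+        model_type="noise",
+        guidance_type="classifier-free",
+        condition=cond,
+        unconditional_condition=uncond,
+        guidance_scale=cfg_scale,
+    )
+    solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
+    x_T = torch.randn(shape, device=device)
+    # Multistep order-2 with uniform time spacing is a common default for ~20 model evaluations.
+    return solver.sample(x_T, steps=steps, skip_type="time_uniform", method="multistep", order=2,
+                         lower_order_final=True)
+
+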
+class DPM_Solver:
+    def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
+        """Construct a DPM-Solver.
+        We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
+        If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
+        If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
+            In that case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
+            The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
+        Args:
+            model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
+                ``
+                def model_fn(x, t_continuous):
+                    return noise
+                ``
+            noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+            predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
+            thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
+            max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
+
+        [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Raphael Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
+        """
+        self.model = model_fn
+        self.noise_schedule = noise_schedule
+        self.predict_x0 = predict_x0
+        self.thresholding = thresholding
+        self.max_val = max_val
+
+    def noise_prediction_fn(self, x, t):
+        """
+        Return the predicted noise of the model at (x, t).
+        """
+        return self.model(x, t)
+
+    def data_prediction_fn(self, x, t):
+        """
+        Return the predicted data x0 at (x, t) (with optional dynamic thresholding).
+        """
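+        # The thresholding branch below is Imagen's "dynamic thresholding": s is the 0.995-quantile
+        # of |x0| per sample (floored at max_val); clamping x0 to [-s, s] and dividing by s keeps
+        # large classifier-free guidance scales from saturating pixel-space samples.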
+        noise = self.noise_prediction_fn(x, t)
+        dims = x.dim()
+        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
+        x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
+        if self.thresholding:
+            p = 0.995  # A hyperparameter in the paper of "Imagen" [1].
+            s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
+            s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
+            x0 = torch.clamp(x0, -s, s) / s
+        return x0
+
+    def model_fn(self, x, t):
+        """
+        Dispatch to the noise prediction model or the data prediction model, depending on `predict_x0`.
+        """
+        if self.predict_x0:
+            return self.data_prediction_fn(x, t)
+        else:
+            return self.noise_prediction_fn(x, t)
+
+    def get_time_steps(self, skip_type, t_T, t_0, N, device):
+        """Compute the intermediate time steps for sampling.
+        Args:
+            skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+                - 'logSNR': uniform logSNR for the time steps.
+                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
+                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
+            t_T: A `float`. The starting time of the sampling (default is T).
+            t_0: A `float`. The ending time of the sampling (default is epsilon).
+            N: An `int`. The number of time-step intervals (the returned tensor has N + 1 entries).
+            device: A torch device.
+        Returns:
+            A pytorch tensor of the time steps, with the shape (N + 1,).
+        """
+        if skip_type == 'logSNR':
+            lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
+            lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
+            logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
+            return self.noise_schedule.inverse_lambda(logSNR_steps)
+        elif skip_type == 'time_uniform':
+            return torch.linspace(t_T, t_0, N + 1).to(device)
+        elif skip_type == 'time_quadratic':
+            t_order = 2
+            t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
+            return t
+        else:
+            raise ValueError(
+                "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
+
+    def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
+        """
+        Get the order of each step for sampling by the singlestep DPM-Solver.
+        We combine DPM-Solver-1, 2 and 3 to use all the function evaluations, which we call "DPM-Solver-fast".
+        Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
+            - If order == 1:
+                We take `steps` of DPM-Solver-1 (i.e. DDIM).
+            - If order == 2:
+                - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
+                - If steps % 2 == 0, we use K steps of DPM-Solver-2.
+                - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
+            - If order == 3:
+                - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+                - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
+                - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
+                - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
+        ============================================
+        Args:
+            order: An `int`. The max order for the solver (1, 2 or 3).
+            steps: An `int`. The total number of function evaluations (NFE).
+            skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+                - 'logSNR': uniform logSNR for the time steps.
+                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
+                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
+            t_T: A `float`. The starting time of the sampling (default is T).
+            t_0: A `float`. The ending time of the sampling (default is epsilon).
+            device: A torch device.
+        Returns:
+            orders: A list of the solver order of each step.
+        """
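+        # Worked example: steps=20 with order=3 gives K = 20 // 3 + 1 = 7 and steps % 3 == 2, so
+        # orders = [3, 3, 3, 3, 3, 3, 2], i.e. 6 * 3 + 2 = 20 function evaluations in total.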
+        if order == 3:
+            K = steps // 3 + 1
+            if steps % 3 == 0:
+                orders = [3, ] * (K - 2) + [2, 1]
+            elif steps % 3 == 1:
+                orders = [3, ] * (K - 1) + [1]
+            else:
+                orders = [3, ] * (K - 1) + [2]
+        elif order == 2:
+            if steps % 2 == 0:
+                K = steps // 2
+                orders = [2, ] * K
+            else:
+                K = steps // 2 + 1
+                orders = [2, ] * (K - 1) + [1]
+        elif order == 1:
+            K = 1
+            orders = [1, ] * steps
+        else:
+            raise ValueError("'order' must be '1' or '2' or '3'.")
+        if skip_type == 'logSNR':
+            # To reproduce the results in DPM-Solver paper
+            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
+        else:
+            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
+                torch.cumsum(torch.tensor([0, ] + orders), dim=0).to(device)]
+        return timesteps_outer, orders
+
+    def denoise_to_zero_fn(self, x, s):
+        """
+        Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infinity by first-order discretization.
+        """
+        return self.data_prediction_fn(x, s)
+
+    def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
+        """
+        DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `s`.
+            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            model_s: A pytorch tensor. The model function evaluated at time `s`.
+                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+            return_intermediate: A `bool`. If true, also return the model value at time `s`.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
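+        # In half-logSNR time h = lambda_t - lambda_s, the update below reads
+        #     x_t = (alpha_t / alpha_s) * x_s - sigma_t * (e^h - 1) * eps_theta(x_s, s)
+        # for the noise-prediction model, and
+        #     x_t = (sigma_t / sigma_s) * x_s - alpha_t * (e^{-h} - 1) * x0_theta(x_s, s)
+        # for the data-prediction (DPM-Solver++) model; a single such step is the DDIM update.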
+        ns = self.noise_schedule
+        dims = x.dim()
+        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+        h = lambda_t - lambda_s
+        log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
+        sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
+        alpha_t = torch.exp(log_alpha_t)
+
+        if self.predict_x0:
+            phi_1 = torch.expm1(-h)
+            if model_s is None:
+                model_s = self.model_fn(x, s)
+            x_t = (
+                    expand_dims(sigma_t / sigma_s, dims) * x
+                    - expand_dims(alpha_t * phi_1, dims) * model_s
+            )
+            if return_intermediate:
+                return x_t, {'model_s': model_s}
+            else:
+                return x_t
+        else:
+            phi_1 = torch.expm1(h)
+            if model_s is None:
+                model_s = self.model_fn(x, s)
+            x_t = (
+                    expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+                    - expand_dims(sigma_t * phi_1, dims) * model_s
+            )
+            if return_intermediate:
+                return x_t, {'model_s': model_s}
+            else:
+                return x_t
+
+    def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
+                                            solver_type='dpm_solver'):
+        """
+        Singlestep solver DPM-Solver-2 from time `s` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `s`.
+            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            r1: A `float`. The hyperparameter of the second-order solver.
+            model_s: A pytorch tensor. The model function evaluated at time `s`.
+                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+            return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        if solver_type not in ['dpm_solver', 'taylor']:
+            raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+        if r1 is None:
+            r1 = 0.5
+        ns = self.noise_schedule
+        dims = x.dim()
+        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+        h = lambda_t - lambda_s
+        lambda_s1 = lambda_s + r1 * h
+        s1 = ns.inverse_lambda(lambda_s1)
+        log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
+            s1), ns.marginal_log_mean_coeff(t)
+        sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
+        alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
+
+        if self.predict_x0:
+            phi_11 = torch.expm1(-r1 * h)
+            phi_1 = torch.expm1(-h)
+
+            if model_s is None:
+                model_s = self.model_fn(x, s)
+            x_s1 = (
+                    expand_dims(sigma_s1 / sigma_s, dims) * x
+                    - expand_dims(alpha_s1 * phi_11, dims) * model_s
+            )
+            model_s1 = self.model_fn(x_s1, s1)
+            if solver_type == 'dpm_solver':
+                x_t = (
+                        expand_dims(sigma_t / sigma_s, dims) * x
+                        - expand_dims(alpha_t * phi_1, dims) * model_s
+                        - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
+                )
+            elif solver_type == 'taylor':
+                x_t = (
+                        expand_dims(sigma_t / sigma_s, dims) * x
+                        - expand_dims(alpha_t * phi_1, dims) * model_s
+                        + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
+                                    model_s1 - model_s)
+                )
+        else:
+            phi_11 = torch.expm1(r1 * h)
+            phi_1 = torch.expm1(h)
+
+            if model_s is None:
+                model_s = self.model_fn(x, s)
+            x_s1 = (
+                    expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+                    - expand_dims(sigma_s1 * phi_11, dims) * model_s
+            )
+            model_s1 = self.model_fn(x_s1, s1)
+            if solver_type == 'dpm_solver':
+                x_t = (
+                        expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+                        - expand_dims(sigma_t * phi_1, dims) * model_s
+                        - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
+                )
+            elif solver_type == 'taylor':
+                x_t = (
+                        expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+                        - expand_dims(sigma_t * phi_1, dims) * model_s
+                        - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
+                )
+        if return_intermediate:
+            return x_t, {'model_s': model_s, 'model_s1': model_s1}
+        else:
+            return x_t
+
+    def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
+                                           return_intermediate=False, solver_type='dpm_solver'):
+        """
+        Singlestep solver DPM-Solver-3 from time `s` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `s`.
+            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            r1: A `float`. The hyperparameter of the third-order solver.
+            r2: A `float`. The hyperparameter of the third-order solver.
+            model_s: A pytorch tensor. The model function evaluated at time `s`.
+                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+            model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
+                If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
+            return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        if solver_type not in ['dpm_solver', 'taylor']:
+            raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+        if r1 is None:
+            r1 = 1. / 3.
+        if r2 is None:
+            r2 = 2. / 3.
+        ns = self.noise_schedule
+        dims = x.dim()
+        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+        h = lambda_t - lambda_s
+        lambda_s1 = lambda_s + r1 * h
+        lambda_s2 = lambda_s + r2 * h
+        s1 = ns.inverse_lambda(lambda_s1)
+        s2 = ns.inverse_lambda(lambda_s2)
+        log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
+            s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
+        sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
+            s2), ns.marginal_std(t)
+        alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
+
+        if self.predict_x0:
+            phi_11 = torch.expm1(-r1 * h)
+            phi_12 = torch.expm1(-r2 * h)
+            phi_1 = torch.expm1(-h)
+            phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
+            phi_2 = phi_1 / h + 1.
+            phi_3 = phi_2 / h - 0.5
+
+            if model_s is None:
+                model_s = self.model_fn(x, s)
+            if model_s1 is None:
+                x_s1 = (
+                        expand_dims(sigma_s1 / sigma_s, dims) * x
+                        - expand_dims(alpha_s1 * phi_11, dims) * model_s
+                )
+                model_s1 = self.model_fn(x_s1, s1)
+            x_s2 = (
+                    expand_dims(sigma_s2 / sigma_s, dims) * x
+                    - expand_dims(alpha_s2 * phi_12, dims) * model_s
+                    + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
+            )
+            model_s2 = self.model_fn(x_s2, s2)
+            if solver_type == 'dpm_solver':
+                x_t = (
+                        expand_dims(sigma_t / sigma_s, dims) * x
+                        - expand_dims(alpha_t * phi_1, dims) * model_s
+                        + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
+                )
+            elif solver_type == 'taylor':
+                D1_0 = (1. / r1) * (model_s1 - model_s)
+                D1_1 = (1. / r2) * (model_s2 - model_s)
+                D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+                D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+                x_t = (
+                        expand_dims(sigma_t / sigma_s, dims) * x
+                        - expand_dims(alpha_t * phi_1, dims) * model_s
+                        + expand_dims(alpha_t * phi_2, dims) * D1
+                        - expand_dims(alpha_t * phi_3, dims) * D2
+                )
+        else:
+            phi_11 = torch.expm1(r1 * h)
+            phi_12 = torch.expm1(r2 * h)
+            phi_1 = torch.expm1(h)
+            phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
+            phi_2 = phi_1 / h - 1.
+            phi_3 = phi_2 / h - 0.5
+
+            if model_s is None:
+                model_s = self.model_fn(x, s)
+            if model_s1 is None:
+                x_s1 = (
+                        expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+                        - expand_dims(sigma_s1 * phi_11, dims) * model_s
+                )
+                model_s1 = self.model_fn(x_s1, s1)
+            x_s2 = (
+                    expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
+                    - expand_dims(sigma_s2 * phi_12, dims) * model_s
+                    - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
+            )
+            model_s2 = self.model_fn(x_s2, s2)
+            if solver_type == 'dpm_solver':
+                x_t = (
+                        expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+                        - expand_dims(sigma_t * phi_1, dims) * model_s
+                        - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
+                )
+            elif solver_type == 'taylor':
+                D1_0 = (1. / r1) * (model_s1 - model_s)
+                D1_1 = (1. / r2) * (model_s2 - model_s)
+                D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+                D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+                x_t = (
+                        expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+                        - expand_dims(sigma_t * phi_1, dims) * model_s
+                        - expand_dims(sigma_t * phi_2, dims) * D1
+                        - expand_dims(sigma_t * phi_3, dims) * D2
+                )
+
+        if return_intermediate:
+            return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
+        else:
+            return x_t
+
+    def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
+        """
+        Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
+            model_prev_list: A list of pytorch tensor. The previous computed model values.
+            t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
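+        # D1_0 below is the backward finite difference (model_prev_0 - model_prev_1) / h_0 rescaled
+        # by h, i.e. an estimate of h * d(model)/d(lambda) at lambda_prev_0; it provides the
+        # second-order correction without any extra model evaluation.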
+        if solver_type not in ['dpm_solver', 'taylor']:
+            raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+        ns = self.noise_schedule
+        dims = x.dim()
+        model_prev_1, model_prev_0 = model_prev_list
+        t_prev_1, t_prev_0 = t_prev_list
+        lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
+            t_prev_0), ns.marginal_lambda(t)
+        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+        alpha_t = torch.exp(log_alpha_t)
+
+        h_0 = lambda_prev_0 - lambda_prev_1
+        h = lambda_t - lambda_prev_0
+        r0 = h_0 / h
+        D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
+        if self.predict_x0:
+            if solver_type == 'dpm_solver':
+                x_t = (
+                        expand_dims(sigma_t / sigma_prev_0, dims) * x
+                        - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+                        - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
+                )
+            elif solver_type == 'taylor':
+                x_t = (
+                        expand_dims(sigma_t / sigma_prev_0, dims) * x
+                        - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+                        + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
+                )
+        else:
+            if solver_type == 'dpm_solver':
+                x_t = (
+                        expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+                        - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+                        - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
+                )
+            elif solver_type == 'taylor':
+                x_t = (
+                        expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+                        - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+                        - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
+                )
+        return x_t
+
+    def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
+        """
+        Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
+            model_prev_list: A list of pytorch tensor. The previous computed model values.
+            t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        ns = self.noise_schedule
+        dims = x.dim()
+        model_prev_2, model_prev_1, model_prev_0 = model_prev_list
+        t_prev_2, t_prev_1, t_prev_0 = t_prev_list
+        lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
+            t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
+        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+        alpha_t = torch.exp(log_alpha_t)
+
+        h_1 = lambda_prev_1 - lambda_prev_2
+        h_0 = lambda_prev_0 - lambda_prev_1
+        h = lambda_t - lambda_prev_0
+        r0, r1 = h_0 / h, h_1 / h
+        D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
+        D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
+        D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
+        D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
+        if self.predict_x0:
+            x_t = (
+                    expand_dims(sigma_t / sigma_prev_0, dims) * x
+                    - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+                    + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
+                    - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2
+            )
+        else:
+            x_t = (
+                    expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+                    - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+                    - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
+                    - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2
+            )
+        return x_t
+
+    def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
+                                     r2=None):
+        """
+        Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `s`.
+            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            order: An `int`. The order of DPM-Solver. We only support order == 1, 2 or 3.
+            return_intermediate: A `bool`. If True, also return the model values at times `s`, `s1` and `s2` (the intermediate times).
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+            r1: A `float`. The hyperparameter of the second-order or third-order solver.
+            r2: A `float`. The hyperparameter of the third-order solver.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        if order == 1:
+            return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
+        elif order == 2:
+            return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
+                                                            solver_type=solver_type, r1=r1)
+        elif order == 3:
+            return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
+                                                           solver_type=solver_type, r1=r1, r2=r2)
+        else:
+            raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+
+    def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
+        """
+        Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
+        Args:
+            x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
+            model_prev_list: A list of pytorch tensors. The previously computed model values.
+            t_prev_list: A list of pytorch tensors. The previous times, each with shape (x.shape[0],).
+            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+            order: An `int`. The order of DPM-Solver. We only support order == 1, 2 or 3.
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        if order == 1:
+            return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
+        elif order == 2:
+            return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+        elif order == 3:
+            return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+        else:
+            raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+
+    def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
+                            solver_type='dpm_solver'):
+        """
+        The adaptive step size solver based on singlestep DPM-Solver.
+        Args:
+            x: A pytorch tensor. The initial value at time `t_T`.
+            order: An `int`. The (higher) order of the solver. We only support order == 2 or 3.
+            t_T: A `float`. The starting time of the sampling (default is T).
+            t_0: A `float`. The ending time of the sampling (default is epsilon).
+            h_init: A `float`. The initial step size (for logSNR).
+            atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
+            rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
+            theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
+            t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
+                current time and `t_0` is less than `t_err`. The default setting is 1e-5.
+            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+        Returns:
+            x_0: A pytorch tensor. The approximated solution at time `t_0`.
+        [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
+        """
+        ns = self.noise_schedule
+        s = t_T * torch.ones((x.shape[0],)).to(x)
+        lambda_s = ns.marginal_lambda(s)
+        lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
+        h = h_init * torch.ones_like(s).to(x)
+        x_prev = x
+        nfe = 0
+        if order == 2:
+            r1 = 0.5
+            lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
+            higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
+                                                                                               solver_type=solver_type,
+                                                                                               **kwargs)
+        elif order == 3:
+            r1, r2 = 1. / 3., 2. / 3.
+            lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
+                                                                                    return_intermediate=True,
+                                                                                    solver_type=solver_type)
+            higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
+                                                                                              solver_type=solver_type,
+                                                                                              **kwargs)
+        else:
+            raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
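+        # Adaptive step-size control: take a lower- and a higher-order step, accept the step only
+        # when the scaled error E <= 1, then adapt h (defaults follow [1] cited in the docstring).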
+        while torch.abs((s - t_0)).mean() > t_err:
+            t = ns.inverse_lambda(lambda_s + h)
+            x_lower, lower_noise_kwargs = lower_update(x, s, t)
+            x_higher = higher_update(x, s, t, **lower_noise_kwargs)
+            delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
+            norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
+            E = norm_fn((x_higher - x_lower) / delta).max()
+            if torch.all(E <= 1.):
+                x = x_higher
+                s = t
+                x_prev = x_lower
+                lambda_s = ns.marginal_lambda(s)
+            h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
+            nfe += order
+        print('adaptive solver nfe', nfe)
+        return x
+
+    def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
+               method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
+               atol=0.0078, rtol=0.05,
+               ):
+        """
+        Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
+        =====================================================
+        We support the following algorithms for both the noise prediction model and the data prediction model:
+            - 'singlestep':
+                Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
+                We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
+                The total number of function evaluations (NFE) == `steps`.
+                Given a fixed NFE == `steps`, the sampling procedure is:
+                    - If `order` == 1:
+                        - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
+                    - If `order` == 2:
+                        - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
+                        - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
+                        - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+                    - If `order` == 3:
+                        - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+                        - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+                        - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
+                        - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
+            - 'multistep':
+                Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
+                We initialize the first `order` values by lower order multistep solvers.
+                Given a fixed NFE == `steps`, the sampling procedure is:
+                    Denote K = steps.
+                    - If `order` == 1:
+                        - We use K steps of DPM-Solver-1 (i.e. DDIM).
+                    - If `order` == 2:
+                        - We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
+                    - If `order` == 3:
+                        - We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
+            - 'singlestep_fixed':
+                Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
+                We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
+            - 'adaptive':
+                Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
+                We ignore `steps` and use the adaptive step size DPM-Solver with the (higher) order `order`.
+                You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation cost
+                (NFE) and the sample quality.
+                    - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
+                    - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
+        =====================================================
+        Some advice on choosing the algorithm:
+            - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
+                Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
+                e.g.
+                    >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
+                    >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
+                            skip_type='time_uniform', method='singlestep')
+            - For **guided sampling with large guidance scale** by DPMs:
+                Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
+                e.g.
+                    >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
+                    >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
+                            skip_type='time_uniform', method='multistep')
+        We support three types of `skip_type`:
+            - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images**.
+            - 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images**.
+            - 'time_quadratic': quadratic time for the time steps.
+        =====================================================
+        Args:
+            x: A pytorch tensor. The initial value at time `t_start`,
+                e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
+            steps: An `int`. The total number of function evaluations (NFE).
+            t_start: A `float`. The starting time of the sampling.
+                If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
+            t_end: A `float`. The ending time of the sampling.
+                If `t_end` is None, we use 1. / self.noise_schedule.total_N.
+                e.g. if total_N == 1000, we have `t_end` == 1e-3.
+                For discrete-time DPMs:
+                    - We recommend `t_end` == 1. / self.noise_schedule.total_N.
+                For continuous-time DPMs:
+                    - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
+            order: An `int`. The order of DPM-Solver.
+            skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
+            method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
+            denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
+                Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
+                This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
+                score_sde (https://arxiv.org/abs/2011.13456). It can improve the FID when
+                sampling diffusion models via diffusion SDEs for low-resolution images
+                (such as CIFAR-10). However, we observed that it does not matter for
+                high-resolution images. As it needs an additional NFE, we do not recommend
+                it for high-resolution images.
+            lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
+                Only valid for `method=multistep` and `steps < 15`. We empirically find that
+                this trick is key to stabilizing sampling with DPM-Solver at very few steps
+                (especially steps <= 10), so we recommend setting it to `True`.
+            solver_type: A `str`. The Taylor expansion type for the solver: `dpm_solver` or `taylor`. We recommend `dpm_solver`.
+            atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+            rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+        Returns:
+            x_end: A pytorch tensor. The approximated solution at time `t_end`.
+        """
+        t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
+        t_T = self.noise_schedule.T if t_start is None else t_start
+        device = x.device
+        if method == 'adaptive':
+            with torch.no_grad():
+                x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
+                                             solver_type=solver_type)
+        elif method == 'multistep':
+            assert steps >= order
+            timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
+            assert timesteps.shape[0] - 1 == steps
+            with torch.no_grad():
+                vec_t = timesteps[0].expand((x.shape[0]))
+                model_prev_list = [self.model_fn(x, vec_t)]
+                t_prev_list = [vec_t]
+                # Init the first `order` values by lower order multistep DPM-Solver.
+                for init_order in tqdm(range(1, order), desc="DPM init order"):
+                    vec_t = timesteps[init_order].expand(x.shape[0])
+                    x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
+                                                         solver_type=solver_type)
+                    model_prev_list.append(self.model_fn(x, vec_t))
+                    t_prev_list.append(vec_t)
+                # Compute the remaining values by `order`-th order multistep DPM-Solver.
+                for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
+                    vec_t = timesteps[step].expand(x.shape[0])
+                    if lower_order_final and steps < 15:
+                        step_order = min(order, steps + 1 - step)
+                    else:
+                        step_order = order
+                    x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
+                                                         solver_type=solver_type)
+                    for i in range(order - 1):
+                        t_prev_list[i] = t_prev_list[i + 1]
+                        model_prev_list[i] = model_prev_list[i + 1]
+                    t_prev_list[-1] = vec_t
+                    # We do not need to evaluate the final model value.
+                    if step < steps:
+                        model_prev_list[-1] = self.model_fn(x, vec_t)
+        elif method in ['singlestep', 'singlestep_fixed']:
+            if method == 'singlestep':
+                timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
+                                                                                              skip_type=skip_type,
+                                                                                              t_T=t_T, t_0=t_0,
+                                                                                              device=device)
+            elif method == 'singlestep_fixed':
+                K = steps // order
+                orders = [order, ] * K
+                timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
+            for i, order in enumerate(orders):
+                t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
+                timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
+                                                      N=order, device=device)
+                lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
+                vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
+                h = lambda_inner[-1] - lambda_inner[0]
+                r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
+                r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
+                x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
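+        # Optional final denoising step (adds one extra model evaluation; see `denoise_to_zero` in the docstring).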
+        if denoise_to_zero:
+            x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
+        return x
+
+
+#############################################################
+# other utility functions
+#############################################################
+
+def interpolate_fn(x, xp, yp):
+    """
+    A piecewise linear function y = f(x), using xp and yp as keypoints.
+    We implement f(x) in a differentiable way (i.e. applicable for autograd).
+    The function f(x) is well-defined for all x. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)
+    Args:
+        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
+        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
+        yp: PyTorch tensor with shape [C, K].
+    Returns:
+        The function values f(x), with shape [N, C].
+    """
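+    # A minimal usage sketch (hypothetical values): with xp = torch.tensor([[0., 1.]]) and
+    # yp = torch.tensor([[0., 2.]]), interpolate_fn(torch.tensor([[0.5]]), xp, yp) ≈ tensor([[1.0]]).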
+    N, K = x.shape[0], xp.shape[1]
+    all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
+    sorted_all_x, x_indices = torch.sort(all_x, dim=2)
+    x_idx = torch.argmin(x_indices, dim=2)
+    cand_start_idx = x_idx - 1
+    start_idx = torch.where(
+        torch.eq(x_idx, 0),
+        torch.tensor(1, device=x.device),
+        torch.where(
+            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+        ),
+    )
+    end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
+    start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
+    end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
+    start_idx2 = torch.where(
+        torch.eq(x_idx, 0),
+        torch.tensor(0, device=x.device),
+        torch.where(
+            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+        ),
+    )
+    y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
+    start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
+    end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
+    cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
+    return cand
+
+
+def expand_dims(v, dims):
+    """
+    Expand the tensor `v` to `dims` dimensions.
+    Args:
+        `v`: a PyTorch tensor with shape [N].
+        `dims`: an `int`. The target number of dimensions.
+    Returns:
+        a PyTorch tensor with shape [N, 1, 1, ..., 1] with a total of `dims` dimensions.
+    """
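+    # e.g. expand_dims(v, 4) turns a tensor of shape [N] into shape [N, 1, 1, 1].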
+    return v[(...,) + (None,) * (dims - 1)]
\ No newline at end of file
diff --git a/ldm/models/diffusion/dpm_solver/sampler.py b/ldm/models/diffusion/dpm_solver/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d137b8cf36718c1c58faa09f9dd919e5fb2977b
--- /dev/null
+++ b/ldm/models/diffusion/dpm_solver/sampler.py
@@ -0,0 +1,87 @@
+"""SAMPLING ONLY."""
+import torch
+
+from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
+
+
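+# Map the LDM parameterization ("eps" or "v") to the model_type string expected by model_wrapper.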
+MODEL_TYPES = {
+    "eps": "noise",
+    "v": "v"
+}
+
+
+class DPMSolverSampler(object):
+    def __init__(self, model, **kwargs):
+        super().__init__()
+        self.model = model
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
+        self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
+
+    def register_buffer(self, name, attr):
+        if type(attr) == torch.Tensor:
+            if attr.device != torch.device("cuda"):
+                attr = attr.to(torch.device("cuda"))
+        setattr(self, name, attr)
+
+    @torch.no_grad()
+    def sample(self,
+               S,
+               batch_size,
+               shape,
+               conditioning=None,
+               callback=None,
+               normals_sequence=None,
+               img_callback=None,
+               quantize_x0=False,
+               eta=0.,
+               mask=None,
+               x0=None,
+               temperature=1.,
+               noise_dropout=0.,
+               score_corrector=None,
+               corrector_kwargs=None,
+               verbose=True,
+               x_T=None,
+               log_every_t=100,
+               unconditional_guidance_scale=1.,
+               unconditional_conditioning=None,
+               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               **kwargs
+               ):
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+
+        print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
+
+        device = self.model.betas.device
+        if x_T is None:
+            img = torch.randn(size, device=device)
+        else:
+            img = x_T
+
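+        # Wrap the model's discrete alphas_cumprod as a VP noise schedule, then build a
+        # classifier-free-guided noise prediction function for DPM-Solver.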
+        ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
+
+        model_fn = model_wrapper(
+            lambda x, t, c: self.model.apply_model(x, t, c),
+            ns,
+            model_type=MODEL_TYPES[self.model.parameterization],
+            guidance_type="classifier-free",
+            condition=conditioning,
+            unconditional_condition=unconditional_conditioning,
+            guidance_scale=unconditional_guidance_scale,
+        )
+
+        dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
+        x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
+
+        return x.to(device), None
\ No newline at end of file
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
new file mode 100644
index 0000000000000000000000000000000000000000..7002a365d27168ced0a04e9a4d83e088f8284eae
--- /dev/null
+++ b/ldm/models/diffusion/plms.py
@@ -0,0 +1,244 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+from functools import partial
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
+from ldm.models.diffusion.sampling_util import norm_thresholding
+
+
+class PLMSSampler(object):
+    def __init__(self, model, schedule="linear", **kwargs):
+        super().__init__()
+        self.model = model
+        self.ddpm_num_timesteps = model.num_timesteps
+        self.schedule = schedule
+
+    def register_buffer(self, name, attr):
+        if type(attr) == torch.Tensor:
+            if attr.device != torch.device("cuda"):
+                attr = attr.to(torch.device("cuda"))
+        setattr(self, name, attr)
+
+    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+        if ddim_eta != 0:
+            raise ValueError('ddim_eta must be 0 for PLMS')
+        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+        alphas_cumprod = self.model.alphas_cumprod
+        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+        self.register_buffer('betas', to_torch(self.model.betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+        # ddim sampling parameters
+        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+                                                                                   ddim_timesteps=self.ddim_timesteps,
+                                                                                   eta=ddim_eta,verbose=verbose)
+        self.register_buffer('ddim_sigmas', ddim_sigmas)
+        self.register_buffer('ddim_alphas', ddim_alphas)
+        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+    @torch.no_grad()
+    def sample(self,
+               S,
+               batch_size,
+               shape,
+               conditioning=None,
+               callback=None,
+               normals_sequence=None,
+               img_callback=None,
+               quantize_x0=False,
+               eta=0.,
+               mask=None,
+               x0=None,
+               temperature=1.,
+               noise_dropout=0.,
+               score_corrector=None,
+               corrector_kwargs=None,
+               verbose=True,
+               x_T=None,
+               log_every_t=100,
+               unconditional_guidance_scale=1.,
+               unconditional_conditioning=None,
+               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               dynamic_threshold=None,
+               **kwargs
+               ):
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+        print(f'Data shape for PLMS sampling is {size}')
+
+        samples, intermediates = self.plms_sampling(conditioning, size,
+                                                    callback=callback,
+                                                    img_callback=img_callback,
+                                                    quantize_denoised=quantize_x0,
+                                                    mask=mask, x0=x0,
+                                                    ddim_use_original_steps=False,
+                                                    noise_dropout=noise_dropout,
+                                                    temperature=temperature,
+                                                    score_corrector=score_corrector,
+                                                    corrector_kwargs=corrector_kwargs,
+                                                    x_T=x_T,
+                                                    log_every_t=log_every_t,
+                                                    unconditional_guidance_scale=unconditional_guidance_scale,
+                                                    unconditional_conditioning=unconditional_conditioning,
+                                                    dynamic_threshold=dynamic_threshold,
+                                                    )
+        return samples, intermediates
+
+    @torch.no_grad()
+    def plms_sampling(self, cond, shape,
+                      x_T=None, ddim_use_original_steps=False,
+                      callback=None, timesteps=None, quantize_denoised=False,
+                      mask=None, x0=None, img_callback=None, log_every_t=100,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None,
+                      dynamic_threshold=None):
+        device = self.model.betas.device
+        b = shape[0]
+        if x_T is None:
+            img = torch.randn(shape, device=device)
+        else:
+            img = x_T
+
+        if timesteps is None:
+            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+        elif timesteps is not None and not ddim_use_original_steps:
+            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+            timesteps = self.ddim_timesteps[:subset_end]
+
+        intermediates = {'x_inter': [img], 'pred_x0': [img]}
+        time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
+        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+        print(f"Running PLMS Sampling with {total_steps} timesteps")
+
+        iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
+        old_eps = []
+
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full((b,), step, device=device, dtype=torch.long)
+            ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
+
+            if mask is not None:
+                assert x0 is not None
+                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
+                img = img_orig * mask + (1. - mask) * img
+
+            outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+                                      quantize_denoised=quantize_denoised, temperature=temperature,
+                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
+                                      corrector_kwargs=corrector_kwargs,
+                                      unconditional_guidance_scale=unconditional_guidance_scale,
+                                      unconditional_conditioning=unconditional_conditioning,
+                                      old_eps=old_eps, t_next=ts_next,
+                                      dynamic_threshold=dynamic_threshold)
+            img, pred_x0, e_t = outs
+            old_eps.append(e_t)
+            if len(old_eps) >= 4:
+                old_eps.pop(0)
+            if callback: callback(i)
+            if img_callback: img_callback(pred_x0, i)
+
+            if index % log_every_t == 0 or index == total_steps - 1:
+                intermediates['x_inter'].append(img)
+                intermediates['pred_x0'].append(pred_x0)
+
+        return img, intermediates
+
+    @torch.no_grad()
+    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
+                      dynamic_threshold=None):
+        b, *_, device = *x.shape, x.device
+
+        def get_model_output(x, t):
+            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+                e_t = self.model.apply_model(x, t, c)
+            else:
+                x_in = torch.cat([x] * 2)
+                t_in = torch.cat([t] * 2)
+                c_in = torch.cat([unconditional_conditioning, c])
+                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+            if score_corrector is not None:
+                assert self.model.parameterization == "eps"
+                e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+            return e_t
+
+        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+
+        def get_x_prev_and_pred_x0(e_t, index):
+            # select parameters corresponding to the currently considered timestep
+            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+            sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+            # current prediction for x_0
+            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+            if quantize_denoised:
+                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+            if dynamic_threshold is not None:
+                pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
+            # direction pointing to x_t
+            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+            if noise_dropout > 0.:
+                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+            return x_prev, pred_x0
+
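+        # Pseudo linear multistep: combine the current noise prediction with up to three previous
+        # ones using Adams-Bashforth-style coefficients; the very first step falls back to a
+        # Heun-style (improved Euler) correction, which needs one extra model evaluation.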
+        e_t = get_model_output(x, t)
+        if len(old_eps) == 0:
+            # Pseudo Improved Euler (2nd order)
+            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+            e_t_next = get_model_output(x_prev, t_next)
+            e_t_prime = (e_t + e_t_next) / 2
+        elif len(old_eps) == 1:
+            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+            e_t_prime = (3 * e_t - old_eps[-1]) / 2
+        elif len(old_eps) == 2:
+            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
+            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+        elif len(old_eps) >= 3:
+            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
+            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+        return x_prev, pred_x0, e_t
diff --git a/ldm/models/diffusion/sampling_util.py b/ldm/models/diffusion/sampling_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eff02be6d7c54d43ee6680636ac0698dd3b3f33
--- /dev/null
+++ b/ldm/models/diffusion/sampling_util.py
@@ -0,0 +1,22 @@
+import torch
+import numpy as np
+
+
+def append_dims(x, target_dims):
+    """Appends dimensions to the end of a tensor until it has target_dims dimensions.
+    From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
+    dims_to_append = target_dims - x.ndim
+    if dims_to_append < 0:
+        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
+    return x[(...,) + (None,) * dims_to_append]
+
+
+def norm_thresholding(x0, value):
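+    # Scale each sample so its RMS over all non-batch dims is at most `value`; samples already
+    # below the threshold are left unchanged (since s is clamped to a minimum of `value`).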
+    s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
+    return x0 * (value / s)
+
+
+def spatial_norm_thresholding(x0, value):
+    # b c h w
+    s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
+    return x0 * (value / s)
\ No newline at end of file
diff --git a/ldm/modules/__pycache__/attention.cpython-38.pyc b/ldm/modules/__pycache__/attention.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d064110b88ac810418e16f5a4b31453ab5447ddd
Binary files /dev/null and b/ldm/modules/__pycache__/attention.cpython-38.pyc differ
diff --git a/ldm/modules/__pycache__/ema.cpython-38.pyc b/ldm/modules/__pycache__/ema.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c3651e3fae7304b3bfc3187f4ec6393384fc438
Binary files /dev/null and b/ldm/modules/__pycache__/ema.cpython-38.pyc differ
diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..509cd873768f0dd75a75ab3fcdd652822b12b59f
--- /dev/null
+++ b/ldm/modules/attention.py
@@ -0,0 +1,341 @@
+from inspect import isfunction
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+from typing import Optional, Any
+
+from ldm.modules.diffusionmodules.util import checkpoint
+
+
+try:
+    import xformers
+    import xformers.ops
+    XFORMERS_IS_AVAILBLE = True
+except ImportError:
+    XFORMERS_IS_AVAILBLE = False
+
+# CrossAttn precision handling
+import os
+_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
+
+def exists(val):
+    return val is not None
+
+
+def uniq(arr):
+    return {el: True for el in arr}.keys()
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+    return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+    dim = tensor.shape[-1]
+    std = 1 / math.sqrt(dim)
+    tensor.uniform_(-std, std)
+    return tensor
+
+
+# feedforward
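+# GEGLU: a GELU-gated linear unit (a GLU variant); the projection doubles the width and
+# one half gates the other through GELU.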
+class GEGLU(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x):
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        dim_out = default(dim_out, dim)
+        project_in = nn.Sequential(
+            nn.Linear(dim, inner_dim),
+            nn.GELU()
+        ) if not glu else GEGLU(dim, inner_dim)
+
+        self.net = nn.Sequential(
+            project_in,
+            nn.Dropout(dropout),
+            nn.Linear(inner_dim, dim_out)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+def zero_module(module):
+    """
+    Zero out the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+def Normalize(in_channels):
+    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class SpatialSelfAttention(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.k = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.v = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.proj_out = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=1,
+                                        stride=1,
+                                        padding=0)
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        b,c,h,w = q.shape
+        q = rearrange(q, 'b c h w -> b (h w) c')
+        k = rearrange(k, 'b c h w -> b c (h w)')
+        w_ = torch.einsum('bij,bjk->bik', q, k)
+
+        w_ = w_ * (int(c)**(-0.5))
+        w_ = torch.nn.functional.softmax(w_, dim=2)
+
+        # attend to values
+        v = rearrange(v, 'b c h w -> b c (h w)')
+        w_ = rearrange(w_, 'b i j -> b j i')
+        h_ = torch.einsum('bij,bjk->bik', v, w_)
+        h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
+        h_ = self.proj_out(h_)
+
+        return x+h_
+
+
+class CrossAttention(nn.Module):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+        super().__init__()
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.scale = dim_head ** -0.5
+        self.heads = heads
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, query_dim),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x, context=None, mask=None):
+        h = self.heads
+
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+        # force cast to fp32 to avoid overflowing
+        if _ATTN_PRECISION == "fp32":
+            with torch.autocast(enabled=False, device_type='cuda'):
+                q, k = q.float(), k.float()
+                sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        else:
+            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+        del q, k
+
+        if exists(mask):
+            mask = rearrange(mask, 'b ... -> b (...)')
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = repeat(mask, 'b j -> (b h) () j', h=h)
+            sim.masked_fill_(~mask, max_neg_value)
+
+        # attention, what we cannot get enough of
+        sim = sim.softmax(dim=-1)
+
+        out = einsum('b i j, b j d -> b i d', sim, v)
+        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+        return self.to_out(out)
+
+
+class MemoryEfficientCrossAttention(nn.Module):
+    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+        super().__init__()
+        print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
+              f"{heads} heads.")
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.heads = heads
+        self.dim_head = dim_head
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.attention_op: Optional[Any] = None
+
+    def forward(self, x, context=None, mask=None):
+        q = self.to_q(x)
+        context = default(context, x)
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+        b, _, _ = q.shape
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(b, t.shape[1], self.heads, self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b * self.heads, t.shape[1], self.dim_head)
+            .contiguous(),
+            (q, k, v),
+        )
+
+        # actually compute the attention, what we cannot get enough of
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+        if exists(mask):
+            raise NotImplementedError
+        out = (
+            out.unsqueeze(0)
+            .reshape(b, self.heads, out.shape[1], self.dim_head)
+            .permute(0, 2, 1, 3)
+            .reshape(b, out.shape[1], self.heads * self.dim_head)
+        )
+        return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+    ATTENTION_MODES = {
+        "softmax": CrossAttention,  # vanilla attention
+        "softmax-xformers": MemoryEfficientCrossAttention
+    }
+    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
+                 disable_self_attn=False):
+        super().__init__()
+        attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
+        assert attn_mode in self.ATTENTION_MODES
+        attn_cls = self.ATTENTION_MODES[attn_mode]
+        self.disable_self_attn = disable_self_attn
+        self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
+                              context_dim=context_dim if self.disable_self_attn else None)  # is a self-attention if not self.disable_self_attn
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+        self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
+                              heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+        self.norm3 = nn.LayerNorm(dim)
+        self.checkpoint = checkpoint
+
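+    # When self.checkpoint is True, activations are recomputed in the backward pass to save memory.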
+    def forward(self, x, context=None):
+        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+
+    def _forward(self, x, context=None):
+        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
+        x = self.attn2(self.norm2(x), context=context) + x
+        x = self.ff(self.norm3(x)) + x
+        return x
+
+
+class SpatialTransformer(nn.Module):
+    """
+    Transformer block for image-like data.
+    First, project the input (aka embedding)
+    and reshape to b, t, d.
+    Then apply standard transformer blocks.
+    Finally, reshape back to an image.
+    NEW: use_linear for more efficiency instead of the 1x1 convs
+    """
+    def __init__(self, in_channels, n_heads, d_head,
+                 depth=1, dropout=0., context_dim=None,
+                 disable_self_attn=False, use_linear=False,
+                 use_checkpoint=True):
+        super().__init__()
+        if exists(context_dim) and not isinstance(context_dim, list):
+            context_dim = [context_dim]
+        self.in_channels = in_channels
+        inner_dim = n_heads * d_head
+        self.norm = Normalize(in_channels)
+        if not use_linear:
+            self.proj_in = nn.Conv2d(in_channels,
+                                     inner_dim,
+                                     kernel_size=1,
+                                     stride=1,
+                                     padding=0)
+        else:
+            self.proj_in = nn.Linear(in_channels, inner_dim)
+
+        self.transformer_blocks = nn.ModuleList(
+            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
+                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
+                for d in range(depth)]
+        )
+        if not use_linear:
+            self.proj_out = zero_module(nn.Conv2d(inner_dim,
+                                                  in_channels,
+                                                  kernel_size=1,
+                                                  stride=1,
+                                                  padding=0))
+        else:
+            # project back from inner_dim to in_channels, mirroring the conv branch above
+            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
+        self.use_linear = use_linear
+
+    def forward(self, x, context=None):
+        # note: if no context is given, cross-attention defaults to self-attention
+        if not isinstance(context, list):
+            context = [context]
+        b, c, h, w = x.shape
+        x_in = x
+        x = self.norm(x)
+        if not self.use_linear:
+            x = self.proj_in(x)
+        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+        if self.use_linear:
+            x = self.proj_in(x)
+        for i, block in enumerate(self.transformer_blocks):
+            x = block(x, context=context[i])
+        if self.use_linear:
+            x = self.proj_out(x)
+        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+        if not self.use_linear:
+            x = self.proj_out(x)
+        return x + x_in
+
diff --git a/ldm/modules/diffusionmodules/__init__.py b/ldm/modules/diffusionmodules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc b/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db34cb9946edc10e5ce3249859299d0c06bce52e
Binary files /dev/null and b/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc differ
diff --git a/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc b/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d8753c3873ccf481f12273e93671a37d192de0b4
Binary files /dev/null and b/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc differ
diff --git a/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc b/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2ef50bf2576c78a6e60e6b82e4a471a03fc3ebe0
Binary files /dev/null and b/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc differ
diff --git a/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc b/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7e65501d679fc4f8f9c062d61504e267c50b1b38
Binary files /dev/null and b/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc differ
diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b089eebbe1676d8249005bb9def002ff5180715b
--- /dev/null
+++ b/ldm/modules/diffusionmodules/model.py
@@ -0,0 +1,852 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+from typing import Optional, Any
+
+from ldm.modules.attention import MemoryEfficientCrossAttention
+
+try:
+    import xformers
+    import xformers.ops
+    XFORMERS_IS_AVAILBLE = True
+except ImportError:
+    XFORMERS_IS_AVAILBLE = False
+    print("No module 'xformers'. Proceeding without it.")
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+    """
+    Build sinusoidal timestep embeddings, matching the implementation in
+    Denoising Diffusion Probabilistic Models (originally from Fairseq).
+    This matches the implementation in tensor2tensor, but differs slightly
+    from the description in Section 3.5 of "Attention Is All You Need".
+    """
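+    # Produces an embedding of shape [len(timesteps), embedding_dim]: sin/cos pairs at
+    # geometrically spaced frequencies (padded with a zero column if embedding_dim is odd).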
+    assert len(timesteps.shape) == 1
+
+    half_dim = embedding_dim // 2
+    emb = math.log(10000) / (half_dim - 1)
+    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+    emb = emb.to(device=timesteps.device)
+    emb = timesteps.float()[:, None] * emb[None, :]
+    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+    if embedding_dim % 2 == 1:  # zero pad
+        emb = torch.nn.functional.pad(emb, (0,1,0,0))
+    return emb
+
+
+def nonlinearity(x):
+    # swish
+    return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            self.conv = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, x):
+        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+        if self.with_conv:
+            x = self.conv(x)
+        return x
+
+
+class Downsample(nn.Module):
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            # no asymmetric padding in torch conv, must do it ourselves
+            self.conv = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=3,
+                                        stride=2,
+                                        padding=0)
+
+    def forward(self, x):
+        if self.with_conv:
+            pad = (0,1,0,1)
+            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+            x = self.conv(x)
+        else:
+            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+        return x
+
+
+class ResnetBlock(nn.Module):
+    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+                 dropout, temb_channels=512):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+        self.use_conv_shortcut = conv_shortcut
+
+        self.norm1 = Normalize(in_channels)
+        self.conv1 = torch.nn.Conv2d(in_channels,
+                                     out_channels,
+                                     kernel_size=3,
+                                     stride=1,
+                                     padding=1)
+        if temb_channels > 0:
+            self.temb_proj = torch.nn.Linear(temb_channels,
+                                             out_channels)
+        self.norm2 = Normalize(out_channels)
+        self.dropout = torch.nn.Dropout(dropout)
+        self.conv2 = torch.nn.Conv2d(out_channels,
+                                     out_channels,
+                                     kernel_size=3,
+                                     stride=1,
+                                     padding=1)
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                self.conv_shortcut = torch.nn.Conv2d(in_channels,
+                                                     out_channels,
+                                                     kernel_size=3,
+                                                     stride=1,
+                                                     padding=1)
+            else:
+                self.nin_shortcut = torch.nn.Conv2d(in_channels,
+                                                    out_channels,
+                                                    kernel_size=1,
+                                                    stride=1,
+                                                    padding=0)
+
+    def forward(self, x, temb):
+        h = x
+        h = self.norm1(h)
+        h = nonlinearity(h)
+        h = self.conv1(h)
+
+        if temb is not None:
+            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+        h = self.norm2(h)
+        h = nonlinearity(h)
+        h = self.dropout(h)
+        h = self.conv2(h)
+
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                x = self.conv_shortcut(x)
+            else:
+                x = self.nin_shortcut(x)
+
+        return x+h
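+
+
+# Editorial sketch (not in the upstream file): how a ResnetBlock is driven with a
+# timestep embedding. The channel widths and embedding size are illustrative; the
+# block projects `temb` internally and adds it to the hidden feature map.
+def _example_resnet_block():
+    block = ResnetBlock(in_channels=64, out_channels=128, dropout=0.0, temb_channels=512)
+    x = torch.randn(1, 64, 32, 32)
+    temb = torch.randn(1, 512)
+    return block(x, temb).shape            # -> torch.Size([1, 128, 32, 32])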
+
+
+class AttnBlock(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.k = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.v = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.proj_out = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=1,
+                                        stride=1,
+                                        padding=0)
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        b,c,h,w = q.shape
+        q = q.reshape(b,c,h*w)
+        q = q.permute(0,2,1)   # b,hw,c
+        k = k.reshape(b,c,h*w) # b,c,hw
+        w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+        w_ = w_ * (int(c)**(-0.5))
+        w_ = torch.nn.functional.softmax(w_, dim=2)
+
+        # attend to values
+        v = v.reshape(b,c,h*w)
+        w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)
+        h_ = torch.bmm(v,w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+        h_ = h_.reshape(b,c,h,w)
+
+        h_ = self.proj_out(h_)
+
+        return x+h_
+
+class MemoryEfficientAttnBlock(nn.Module):
+    """
+        Uses xformers efficient implementation,
+        see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+        Note: this is a single-head self-attention operation
+    """
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.k = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.v = torch.nn.Conv2d(in_channels,
+                                 in_channels,
+                                 kernel_size=1,
+                                 stride=1,
+                                 padding=0)
+        self.proj_out = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=1,
+                                        stride=1,
+                                        padding=0)
+        self.attention_op: Optional[Any] = None
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        B, C, H, W = q.shape
+        q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
+
+        q, k, v = map(
+            lambda t: t.unsqueeze(3)
+            .reshape(B, t.shape[1], 1, C)
+            .permute(0, 2, 1, 3)
+            .reshape(B * 1, t.shape[1], C)
+            .contiguous(),
+            (q, k, v),
+        )
+        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+        out = (
+            out.unsqueeze(0)
+            .reshape(B, 1, out.shape[1], C)
+            .permute(0, 2, 1, 3)
+            .reshape(B, out.shape[1], C)
+        )
+        out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
+        out = self.proj_out(out)
+        return x+out
+
+
+class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
+    def forward(self, x, context=None, mask=None):
+        b, c, h, w = x.shape
+        x_in = x
+        x = rearrange(x, 'b c h w -> b (h w) c')
+        out = super().forward(x, context=context, mask=mask)
+        out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
+        return x_in + out
+
+
+def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
+    assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
+    if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
+        attn_type = "vanilla-xformers"
+    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+    if attn_type == "vanilla":
+        assert attn_kwargs is None
+        return AttnBlock(in_channels)
+    elif attn_type == "vanilla-xformers":
+        print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
+        return MemoryEfficientAttnBlock(in_channels)
+    elif type == "memory-efficient-cross-attn":
+        attn_kwargs["query_dim"] = in_channels
+        return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
+    elif attn_type == "none":
+        return nn.Identity(in_channels)
+    else:
+        raise NotImplementedError()
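+
+
+# Editorial sketch (not in the upstream file): make_attn() above prefers the xformers
+# block whenever the library imports; this hedged example forces the plain AttnBlock
+# so it also runs on CPU-only setups. Self-attention preserves the feature-map shape.
+def _example_attention_block():
+    attn = AttnBlock(64)
+    x = torch.randn(1, 64, 16, 16)
+    return attn(x).shape                   # -> torch.Size([1, 64, 16, 16])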
+
+
+class Model(nn.Module):
+    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+                 resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
+        super().__init__()
+        if use_linear_attn: attn_type = "linear"
+        self.ch = ch
+        self.temb_ch = self.ch*4
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+
+        self.use_timestep = use_timestep
+        if self.use_timestep:
+            # timestep embedding
+            self.temb = nn.Module()
+            self.temb.dense = nn.ModuleList([
+                torch.nn.Linear(self.ch,
+                                self.temb_ch),
+                torch.nn.Linear(self.temb_ch,
+                                self.temb_ch),
+            ])
+
+        # downsampling
+        self.conv_in = torch.nn.Conv2d(in_channels,
+                                       self.ch,
+                                       kernel_size=3,
+                                       stride=1,
+                                       padding=1)
+
+        curr_res = resolution
+        in_ch_mult = (1,)+tuple(ch_mult)
+        self.down = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_in = ch*in_ch_mult[i_level]
+            block_out = ch*ch_mult[i_level]
+            for i_block in range(self.num_res_blocks):
+                block.append(ResnetBlock(in_channels=block_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(make_attn(block_in, attn_type=attn_type))
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions-1:
+                down.downsample = Downsample(block_in, resamp_with_conv)
+                curr_res = curr_res // 2
+            self.down.append(down)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = ch*ch_mult[i_level]
+            skip_in = ch*ch_mult[i_level]
+            for i_block in range(self.num_res_blocks+1):
+                if i_block == self.num_res_blocks:
+                    skip_in = ch*in_ch_mult[i_level]
+                block.append(ResnetBlock(in_channels=block_in+skip_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(make_attn(block_in, attn_type=attn_type))
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                up.upsample = Upsample(block_in, resamp_with_conv)
+                curr_res = curr_res * 2
+            self.up.insert(0, up) # prepend to get consistent order
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(block_in,
+                                        out_ch,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, x, t=None, context=None):
+        #assert x.shape[2] == x.shape[3] == self.resolution
+        if context is not None:
+            # assume aligned context, cat along channel axis
+            x = torch.cat((x, context), dim=1)
+        if self.use_timestep:
+            # timestep embedding
+            assert t is not None
+            temb = get_timestep_embedding(t, self.ch)
+            temb = self.temb.dense[0](temb)
+            temb = nonlinearity(temb)
+            temb = self.temb.dense[1](temb)
+        else:
+            temb = None
+
+        # downsampling
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level].block[i_block](hs[-1], temb)
+                if len(self.down[i_level].attn) > 0:
+                    h = self.down[i_level].attn[i_block](h)
+                hs.append(h)
+            if i_level != self.num_resolutions-1:
+                hs.append(self.down[i_level].downsample(hs[-1]))
+
+        # middle
+        h = hs[-1]
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks+1):
+                h = self.up[i_level].block[i_block](
+                    torch.cat([h, hs.pop()], dim=1), temb)
+                if len(self.up[i_level].attn) > 0:
+                    h = self.up[i_level].attn[i_block](h)
+            if i_level != 0:
+                h = self.up[i_level].upsample(h)
+
+        # end
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+    def get_last_layer(self):
+        return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
+                 **ignore_kwargs):
+        super().__init__()
+        if use_linear_attn: attn_type = "linear"
+        self.ch = ch
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+
+        # downsampling
+        self.conv_in = torch.nn.Conv2d(in_channels,
+                                       self.ch,
+                                       kernel_size=3,
+                                       stride=1,
+                                       padding=1)
+
+        curr_res = resolution
+        in_ch_mult = (1,)+tuple(ch_mult)
+        self.in_ch_mult = in_ch_mult
+        self.down = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_in = ch*in_ch_mult[i_level]
+            block_out = ch*ch_mult[i_level]
+            for i_block in range(self.num_res_blocks):
+                block.append(ResnetBlock(in_channels=block_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(make_attn(block_in, attn_type=attn_type))
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions-1:
+                down.downsample = Downsample(block_in, resamp_with_conv)
+                curr_res = curr_res // 2
+            self.down.append(down)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(block_in,
+                                        2*z_channels if double_z else z_channels,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, x):
+        # timestep embedding
+        temb = None
+
+        # downsampling
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level].block[i_block](hs[-1], temb)
+                if len(self.down[i_level].attn) > 0:
+                    h = self.down[i_level].attn[i_block](h)
+                hs.append(h)
+            if i_level != self.num_resolutions-1:
+                hs.append(self.down[i_level].downsample(hs[-1]))
+
+        # middle
+        h = hs[-1]
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # end
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
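+
+
+# Editorial sketch (not in the upstream file): an Encoder wired roughly like a small
+# VAE first stage. All hyperparameters are illustrative assumptions; with double_z=True
+# the output stacks mean and log-variance along the channel axis. Note the mid-block
+# attention follows make_attn(), so with xformers installed the forward pass may
+# expect a CUDA device.
+def _example_encoder():
+    enc = Encoder(ch=32, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=1,
+                  attn_resolutions=[], dropout=0.0, in_channels=3,
+                  resolution=64, z_channels=4, double_z=True)
+    x = torch.randn(1, 3, 64, 64)
+    return enc(x).shape                    # -> torch.Size([1, 8, 16, 16])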
+
+
+class Decoder(nn.Module):
+    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+                 attn_type="vanilla", **ignorekwargs):
+        super().__init__()
+        if use_linear_attn: attn_type = "linear"
+        self.ch = ch
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+        self.give_pre_end = give_pre_end
+        self.tanh_out = tanh_out
+
+        # compute in_ch_mult, block_in and curr_res at lowest res
+        in_ch_mult = (1,)+tuple(ch_mult)
+        block_in = ch*ch_mult[self.num_resolutions-1]
+        curr_res = resolution // 2**(self.num_resolutions-1)
+        self.z_shape = (1,z_channels,curr_res,curr_res)
+        print("Working with z of shape {} = {} dimensions.".format(
+            self.z_shape, np.prod(self.z_shape)))
+
+        # z to block_in
+        self.conv_in = torch.nn.Conv2d(z_channels,
+                                       block_in,
+                                       kernel_size=3,
+                                       stride=1,
+                                       padding=1)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in,
+                                       out_channels=block_in,
+                                       temb_channels=self.temb_ch,
+                                       dropout=dropout)
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = ch*ch_mult[i_level]
+            for i_block in range(self.num_res_blocks+1):
+                block.append(ResnetBlock(in_channels=block_in,
+                                         out_channels=block_out,
+                                         temb_channels=self.temb_ch,
+                                         dropout=dropout))
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(make_attn(block_in, attn_type=attn_type))
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                up.upsample = Upsample(block_in, resamp_with_conv)
+                curr_res = curr_res * 2
+            self.up.insert(0, up) # prepend to get consistent order
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(block_in,
+                                        out_ch,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, z):
+        #assert z.shape[1:] == self.z_shape[1:]
+        self.last_z_shape = z.shape
+
+        # timestep embedding
+        temb = None
+
+        # z to block_in
+        h = self.conv_in(z)
+
+        # middle
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks+1):
+                h = self.up[i_level].block[i_block](h, temb)
+                if len(self.up[i_level].attn) > 0:
+                    h = self.up[i_level].attn[i_block](h)
+            if i_level != 0:
+                h = self.up[i_level].upsample(h)
+
+        # end
+        if self.give_pre_end:
+            return h
+
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        if self.tanh_out:
+            h = torch.tanh(h)
+        return h
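+
+
+# Editorial sketch (not in the upstream file): the Decoder mirrors the Encoder sketch
+# above, mapping a latent back to image space (same xformers caveat applies). The
+# hyperparameters are illustrative assumptions: a 16x16 latent decodes to 64x64.
+def _example_decoder():
+    dec = Decoder(ch=32, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=1,
+                  attn_resolutions=[], dropout=0.0, in_channels=3,
+                  resolution=64, z_channels=4)
+    z = torch.randn(1, 4, 16, 16)
+    return dec(z).shape                    # -> torch.Size([1, 3, 64, 64])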
+
+
+class SimpleDecoder(nn.Module):
+    def __init__(self, in_channels, out_channels, *args, **kwargs):
+        super().__init__()
+        self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
+                                     ResnetBlock(in_channels=in_channels,
+                                                 out_channels=2 * in_channels,
+                                                 temb_channels=0, dropout=0.0),
+                                     ResnetBlock(in_channels=2 * in_channels,
+                                                out_channels=4 * in_channels,
+                                                temb_channels=0, dropout=0.0),
+                                     ResnetBlock(in_channels=4 * in_channels,
+                                                out_channels=2 * in_channels,
+                                                temb_channels=0, dropout=0.0),
+                                     nn.Conv2d(2*in_channels, in_channels, 1),
+                                     Upsample(in_channels, with_conv=True)])
+        # end
+        self.norm_out = Normalize(in_channels)
+        self.conv_out = torch.nn.Conv2d(in_channels,
+                                        out_channels,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, x):
+        for i, layer in enumerate(self.model):
+            if i in [1,2,3]:
+                x = layer(x, None)
+            else:
+                x = layer(x)
+
+        h = self.norm_out(x)
+        h = nonlinearity(h)
+        x = self.conv_out(h)
+        return x
+
+
+class UpsampleDecoder(nn.Module):
+    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+                 ch_mult=(2,2), dropout=0.0):
+        super().__init__()
+        # upsampling
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        block_in = in_channels
+        curr_res = resolution // 2 ** (self.num_resolutions - 1)
+        self.res_blocks = nn.ModuleList()
+        self.upsample_blocks = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            res_block = []
+            block_out = ch * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks + 1):
+                res_block.append(ResnetBlock(in_channels=block_in,
+                                             out_channels=block_out,
+                                             temb_channels=self.temb_ch,
+                                             dropout=dropout))
+                block_in = block_out
+            self.res_blocks.append(nn.ModuleList(res_block))
+            if i_level != self.num_resolutions - 1:
+                self.upsample_blocks.append(Upsample(block_in, True))
+                curr_res = curr_res * 2
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(block_in,
+                                        out_channels,
+                                        kernel_size=3,
+                                        stride=1,
+                                        padding=1)
+
+    def forward(self, x):
+        # upsampling
+        h = x
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks + 1):
+                h = self.res_blocks[i_level][i_block](h, None)
+            if i_level != self.num_resolutions - 1:
+                h = self.upsample_blocks[i_level](h)
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+
+class LatentRescaler(nn.Module):
+    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
+        super().__init__()
+        # residual block, interpolate, residual block
+        self.factor = factor
+        self.conv_in = nn.Conv2d(in_channels,
+                                 mid_channels,
+                                 kernel_size=3,
+                                 stride=1,
+                                 padding=1)
+        self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+                                                     out_channels=mid_channels,
+                                                     temb_channels=0,
+                                                     dropout=0.0) for _ in range(depth)])
+        self.attn = AttnBlock(mid_channels)
+        self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+                                                     out_channels=mid_channels,
+                                                     temb_channels=0,
+                                                     dropout=0.0) for _ in range(depth)])
+
+        self.conv_out = nn.Conv2d(mid_channels,
+                                  out_channels,
+                                  kernel_size=1,
+                                  )
+
+    def forward(self, x):
+        x = self.conv_in(x)
+        for block in self.res_block1:
+            x = block(x, None)
+        x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
+        x = self.attn(x)
+        for block in self.res_block2:
+            x = block(x, None)
+        x = self.conv_out(x)
+        return x
+
+
+class MergedRescaleEncoder(nn.Module):
+    def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
+                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
+                 ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
+        super().__init__()
+        intermediate_chn = ch * ch_mult[-1]
+        self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
+                               z_channels=intermediate_chn, double_z=False, resolution=resolution,
+                               attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
+                               out_ch=None)
+        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
+                                       mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.rescaler(x)
+        return x
+
+
+class MergedRescaleDecoder(nn.Module):
+    def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
+                 dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
+        super().__init__()
+        tmp_chn = z_channels*ch_mult[-1]
+        self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
+                               resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
+                               ch_mult=ch_mult, resolution=resolution, ch=ch)
+        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
+                                       out_channels=tmp_chn, depth=rescale_module_depth)
+
+    def forward(self, x):
+        x = self.rescaler(x)
+        x = self.decoder(x)
+        return x
+
+
+class Upsampler(nn.Module):
+    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+        super().__init__()
+        assert out_size >= in_size
+        num_blocks = int(np.log2(out_size//in_size))+1
+        factor_up = 1.+ (out_size % in_size)
+        print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
+        self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
+                                       out_channels=in_channels)
+        self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
+                               attn_resolutions=[], in_channels=None, ch=in_channels,
+                               ch_mult=[ch_mult for _ in range(num_blocks)])
+
+    def forward(self, x):
+        x = self.rescaler(x)
+        x = self.decoder(x)
+        return x
+
+
+class Resize(nn.Module):
+    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+        super().__init__()
+        self.with_conv = learned
+        self.mode = mode
+        if self.with_conv:
+            print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
+            raise NotImplementedError()
+            assert in_channels is not None
+            # no asymmetric padding in torch conv, must do it ourselves
+            self.conv = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=4,
+                                        stride=2,
+                                        padding=1)
+
+    def forward(self, x, scale_factor=1.0):
+        if scale_factor==1.0:
+            return x
+        else:
+            x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
+        return x
diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..7df6b5abfe8eff07f0c8e8703ba8aee90d45984b
--- /dev/null
+++ b/ldm/modules/diffusionmodules/openaimodel.py
@@ -0,0 +1,786 @@
+from abc import abstractmethod
+import math
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ldm.modules.diffusionmodules.util import (
+    checkpoint,
+    conv_nd,
+    linear,
+    avg_pool_nd,
+    zero_module,
+    normalization,
+    timestep_embedding,
+)
+from ldm.modules.attention import SpatialTransformer
+from ldm.util import exists
+
+
+# dummy replace
+def convert_module_to_f16(x):
+    pass
+
+def convert_module_to_f32(x):
+    pass
+
+
+## go
+class AttentionPool2d(nn.Module):
+    """
+    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+    """
+
+    def __init__(
+        self,
+        spacial_dim: int,
+        embed_dim: int,
+        num_heads_channels: int,
+        output_dim: int = None,
+    ):
+        super().__init__()
+        self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
+        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+        self.num_heads = embed_dim // num_heads_channels
+        self.attention = QKVAttention(self.num_heads)
+
+    def forward(self, x):
+        b, c, *_spatial = x.shape
+        x = x.reshape(b, c, -1)  # NC(HW)
+        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
+        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
+        x = self.qkv_proj(x)
+        x = self.attention(x)
+        x = self.c_proj(x)
+        return x[:, :, 0]
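+
+
+# Editorial sketch (not in the upstream file): AttentionPool2d squeezes a (B, C, H, W)
+# feature map down to a (B, C) vector, CLIP-style. The sizes below are illustrative;
+# spacial_dim must match the spatial side length of the input.
+def _example_attention_pool():
+    pool = AttentionPool2d(spacial_dim=7, embed_dim=64, num_heads_channels=8)
+    x = th.randn(2, 64, 7, 7)
+    return pool(x).shape                   # -> torch.Size([2, 64])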
+
+
+class TimestepBlock(nn.Module):
+    """
+    Any module where forward() takes timestep embeddings as a second argument.
+    """
+
+    @abstractmethod
+    def forward(self, x, emb):
+        """
+        Apply the module to `x` given `emb` timestep embeddings.
+        """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+    """
+    A sequential module that passes timestep embeddings to the children that
+    support it as an extra input.
+    """
+
+    def forward(self, x, emb, context=None):
+        for layer in self:
+            if isinstance(layer, TimestepBlock):
+                x = layer(x, emb)
+            elif isinstance(layer, SpatialTransformer):
+                x = layer(x, context)
+            else:
+                x = layer(x)
+        return x
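+
+
+# Editorial sketch (not in the upstream file): TimestepEmbedSequential routes `emb`
+# only to TimestepBlock children (and `context` only to SpatialTransformer children).
+# ResBlock is defined further down in this module; the name resolves at call time.
+# Shapes and widths here are illustrative assumptions.
+def _example_timestep_sequential():
+    seq = TimestepEmbedSequential(
+        conv_nd(2, 8, 32, 3, padding=1),               # plain layer: ignores the embedding
+        ResBlock(32, emb_channels=64, dropout=0.0),    # TimestepBlock: consumes it
+    )
+    x = th.randn(1, 8, 24, 24)
+    emb = th.randn(1, 64)
+    return seq(x, emb).shape               # -> torch.Size([1, 32, 24, 24])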
+
+
+class Upsample(nn.Module):
+    """
+    An upsampling layer with an optional convolution.
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+                 upsampling occurs in the inner-two dimensions.
+    """
+
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.dims = dims
+        if use_conv:
+            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        if self.dims == 3:
+            x = F.interpolate(
+                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+            )
+        else:
+            x = F.interpolate(x, scale_factor=2, mode="nearest")
+        if self.use_conv:
+            x = self.conv(x)
+        return x
+
+class TransposedUpsample(nn.Module):
+    'Learned 2x upsampling without padding'
+    def __init__(self, channels, out_channels=None, ks=5):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+
+        self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
+
+    def forward(self,x):
+        return self.up(x)
+
+
+class Downsample(nn.Module):
+    """
+    A downsampling layer with an optional convolution.
+    :param channels: channels in the inputs and outputs.
+    :param use_conv: a bool determining if a convolution is applied.
+    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+                 downsampling occurs in the inner-two dimensions.
+    """
+
+    def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.dims = dims
+        stride = 2 if dims != 3 else (1, 2, 2)
+        if use_conv:
+            self.op = conv_nd(
+                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+            )
+        else:
+            assert self.channels == self.out_channels
+            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+    def forward(self, x):
+        assert x.shape[1] == self.channels
+        return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+    """
+    A residual block that can optionally change the number of channels.
+    :param channels: the number of input channels.
+    :param emb_channels: the number of timestep embedding channels.
+    :param dropout: the rate of dropout.
+    :param out_channels: if specified, the number of out channels.
+    :param use_conv: if True and out_channels is specified, use a spatial
+        convolution instead of a smaller 1x1 convolution to change the
+        channels in the skip connection.
+    :param dims: determines if the signal is 1D, 2D, or 3D.
+    :param use_checkpoint: if True, use gradient checkpointing on this module.
+    :param up: if True, use this block for upsampling.
+    :param down: if True, use this block for downsampling.
+    """
+
+    def __init__(
+        self,
+        channels,
+        emb_channels,
+        dropout,
+        out_channels=None,
+        use_conv=False,
+        use_scale_shift_norm=False,
+        dims=2,
+        use_checkpoint=False,
+        up=False,
+        down=False,
+    ):
+        super().__init__()
+        self.channels = channels
+        self.emb_channels = emb_channels
+        self.dropout = dropout
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.use_checkpoint = use_checkpoint
+        self.use_scale_shift_norm = use_scale_shift_norm
+
+        self.in_layers = nn.Sequential(
+            normalization(channels),
+            nn.SiLU(),
+            conv_nd(dims, channels, self.out_channels, 3, padding=1),
+        )
+
+        self.updown = up or down
+
+        if up:
+            self.h_upd = Upsample(channels, False, dims)
+            self.x_upd = Upsample(channels, False, dims)
+        elif down:
+            self.h_upd = Downsample(channels, False, dims)
+            self.x_upd = Downsample(channels, False, dims)
+        else:
+            self.h_upd = self.x_upd = nn.Identity()
+
+        self.emb_layers = nn.Sequential(
+            nn.SiLU(),
+            linear(
+                emb_channels,
+                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+            ),
+        )
+        self.out_layers = nn.Sequential(
+            normalization(self.out_channels),
+            nn.SiLU(),
+            nn.Dropout(p=dropout),
+            zero_module(
+                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+            ),
+        )
+
+        if self.out_channels == channels:
+            self.skip_connection = nn.Identity()
+        elif use_conv:
+            self.skip_connection = conv_nd(
+                dims, channels, self.out_channels, 3, padding=1
+            )
+        else:
+            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+    def forward(self, x, emb):
+        """
+        Apply the block to a Tensor, conditioned on a timestep embedding.
+        :param x: an [N x C x ...] Tensor of features.
+        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+        :return: an [N x C x ...] Tensor of outputs.
+        """
+        return checkpoint(
+            self._forward, (x, emb), self.parameters(), self.use_checkpoint
+        )
+
+
+    def _forward(self, x, emb):
+        if self.updown:
+            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+            h = in_rest(x)
+            h = self.h_upd(h)
+            x = self.x_upd(x)
+            h = in_conv(h)
+        else:
+            h = self.in_layers(x)
+        emb_out = self.emb_layers(emb).type(h.dtype)
+        while len(emb_out.shape) < len(h.shape):
+            emb_out = emb_out[..., None]
+        if self.use_scale_shift_norm:
+            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+            scale, shift = th.chunk(emb_out, 2, dim=1)
+            h = out_norm(h) * (1 + scale) + shift
+            h = out_rest(h)
+        else:
+            h = h + emb_out
+            h = self.out_layers(h)
+        return self.skip_connection(x) + h
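+
+
+# Editorial sketch (not in the upstream file): a ResBlock that changes channel width
+# and halves the resolution, using the FiLM-style scale/shift conditioning described
+# in the docstring. All sizes are illustrative assumptions.
+def _example_resblock_downsample():
+    block = ResBlock(64, emb_channels=128, dropout=0.0, out_channels=96,
+                     use_scale_shift_norm=True, down=True)
+    x = th.randn(2, 64, 32, 32)
+    emb = th.randn(2, 128)
+    return block(x, emb).shape             # -> torch.Size([2, 96, 16, 16])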
+
+
+class AttentionBlock(nn.Module):
+    """
+    An attention block that allows spatial positions to attend to each other.
+    Originally ported from here, but adapted to the N-d case.
+    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+    """
+
+    def __init__(
+        self,
+        channels,
+        num_heads=1,
+        num_head_channels=-1,
+        use_checkpoint=False,
+        use_new_attention_order=False,
+    ):
+        super().__init__()
+        self.channels = channels
+        if num_head_channels == -1:
+            self.num_heads = num_heads
+        else:
+            assert (
+                channels % num_head_channels == 0
+            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+            self.num_heads = channels // num_head_channels
+        self.use_checkpoint = use_checkpoint
+        self.norm = normalization(channels)
+        self.qkv = conv_nd(1, channels, channels * 3, 1)
+        if use_new_attention_order:
+            # split qkv before split heads
+            self.attention = QKVAttention(self.num_heads)
+        else:
+            # split heads before split qkv
+            self.attention = QKVAttentionLegacy(self.num_heads)
+
+        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+    def forward(self, x):
+        return checkpoint(self._forward, (x,), self.parameters(), True)   # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+        #return pt_checkpoint(self._forward, x)  # pytorch
+
+    def _forward(self, x):
+        b, c, *spatial = x.shape
+        x = x.reshape(b, c, -1)
+        qkv = self.qkv(self.norm(x))
+        h = self.attention(qkv)
+        h = self.proj_out(h)
+        return (x + h).reshape(b, c, *spatial)
+
+
+def count_flops_attn(model, _x, y):
+    """
+    A counter for the `thop` package to count the operations in an
+    attention operation.
+    Meant to be used like:
+        macs, params = thop.profile(
+            model,
+            inputs=(inputs, timestamps),
+            custom_ops={QKVAttention: QKVAttention.count_flops},
+        )
+    """
+    b, c, *spatial = y[0].shape
+    num_spatial = int(np.prod(spatial))
+    # We perform two matmuls with the same number of ops.
+    # The first computes the weight matrix, the second computes
+    # the combination of the value vectors.
+    matmul_ops = 2 * b * (num_spatial ** 2) * c
+    model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+    """
+    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
+    """
+
+    def __init__(self, n_heads):
+        super().__init__()
+        self.n_heads = n_heads
+
+    def forward(self, qkv):
+        """
+        Apply QKV attention.
+        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+        :return: an [N x (H * C) x T] tensor after attention.
+        """
+        bs, width, length = qkv.shape
+        assert width % (3 * self.n_heads) == 0
+        ch = width // (3 * self.n_heads)
+        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+        scale = 1 / math.sqrt(math.sqrt(ch))
+        weight = th.einsum(
+            "bct,bcs->bts", q * scale, k * scale
+        )  # More stable with f16 than dividing afterwards
+        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+        a = th.einsum("bts,bcs->bct", weight, v)
+        return a.reshape(bs, -1, length)
+
+    @staticmethod
+    def count_flops(model, _x, y):
+        return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+    """
+    A module which performs QKV attention and splits in a different order.
+    """
+
+    def __init__(self, n_heads):
+        super().__init__()
+        self.n_heads = n_heads
+
+    def forward(self, qkv):
+        """
+        Apply QKV attention.
+        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+        :return: an [N x (H * C) x T] tensor after attention.
+        """
+        bs, width, length = qkv.shape
+        assert width % (3 * self.n_heads) == 0
+        ch = width // (3 * self.n_heads)
+        q, k, v = qkv.chunk(3, dim=1)
+        scale = 1 / math.sqrt(math.sqrt(ch))
+        weight = th.einsum(
+            "bct,bcs->bts",
+            (q * scale).view(bs * self.n_heads, ch, length),
+            (k * scale).view(bs * self.n_heads, ch, length),
+        )  # More stable with f16 than dividing afterwards
+        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+        return a.reshape(bs, -1, length)
+
+    @staticmethod
+    def count_flops(model, _x, y):
+        return count_flops_attn(model, _x, y)
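+
+
+# Editorial sketch (not in the upstream file): the packed-QKV convention used by the
+# attention modules above. With H=4 heads and C=16 channels per head, the input is
+# [N x (3 * H * C) x T] and the output is [N x (H * C) x T]. Sizes are illustrative.
+def _example_qkv_attention():
+    attn = QKVAttention(n_heads=4)
+    qkv = th.randn(2, 3 * 4 * 16, 10)
+    return attn(qkv).shape                 # -> torch.Size([2, 64, 10])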
+
+
+class UNetModel(nn.Module):
+    """
+    The full UNet model with attention and timestep embedding.
+    :param in_channels: channels in the input Tensor.
+    :param model_channels: base channel count for the model.
+    :param out_channels: channels in the output Tensor.
+    :param num_res_blocks: number of residual blocks per downsample.
+    :param attention_resolutions: a collection of downsample rates at which
+        attention will take place. May be a set, list, or tuple.
+        For example, if this contains 4, then at 4x downsampling, attention
+        will be used.
+    :param dropout: the dropout probability.
+    :param channel_mult: channel multiplier for each level of the UNet.
+    :param conv_resample: if True, use learned convolutions for upsampling and
+        downsampling.
+    :param dims: determines if the signal is 1D, 2D, or 3D.
+    :param num_classes: if specified (as an int), then this model will be
+        class-conditional with `num_classes` classes.
+    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+    :param num_heads: the number of attention heads in each attention layer.
+    :param num_heads_channels: if specified, ignore num_heads and instead use
+                               a fixed channel width per attention head.
+    :param num_heads_upsample: works with num_heads to set a different number
+                               of heads for upsampling. Deprecated.
+    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+    :param resblock_updown: use residual blocks for up/downsampling.
+    :param use_new_attention_order: use a different attention pattern for potentially
+                                    increased efficiency.
+    """
+
+    def __init__(
+        self,
+        image_size,
+        in_channels,
+        model_channels,
+        out_channels,
+        num_res_blocks,
+        attention_resolutions,
+        dropout=0,
+        channel_mult=(1, 2, 4, 8),
+        conv_resample=True,
+        dims=2,
+        num_classes=None,
+        use_checkpoint=False,
+        use_fp16=False,
+        num_heads=-1,
+        num_head_channels=-1,
+        num_heads_upsample=-1,
+        use_scale_shift_norm=False,
+        resblock_updown=False,
+        use_new_attention_order=False,
+        use_spatial_transformer=False,    # custom transformer support
+        transformer_depth=1,              # custom transformer support
+        context_dim=None,                 # custom transformer support
+        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
+        legacy=True,
+        disable_self_attentions=None,
+        num_attention_blocks=None,
+        disable_middle_self_attn=False,
+        use_linear_in_transformer=False,
+    ):
+        super().__init__()
+        if use_spatial_transformer:
+            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+        if context_dim is not None:
+            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+            from omegaconf.listconfig import ListConfig
+            if type(context_dim) == ListConfig:
+                context_dim = list(context_dim)
+
+        if num_heads_upsample == -1:
+            num_heads_upsample = num_heads
+
+        if num_heads == -1:
+            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+        if num_head_channels == -1:
+            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+        self.image_size = image_size
+        self.in_channels = in_channels
+        self.model_channels = model_channels
+        self.out_channels = out_channels
+        if isinstance(num_res_blocks, int):
+            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+        else:
+            if len(num_res_blocks) != len(channel_mult):
+                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+                                 "as a list/tuple (per-level) with the same length as channel_mult")
+            self.num_res_blocks = num_res_blocks
+        if disable_self_attentions is not None:
+            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+            assert len(disable_self_attentions) == len(channel_mult)
+        if num_attention_blocks is not None:
+            assert len(num_attention_blocks) == len(self.num_res_blocks)
+            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+                  f"attention will still not be set.")
+
+        self.attention_resolutions = attention_resolutions
+        self.dropout = dropout
+        self.channel_mult = channel_mult
+        self.conv_resample = conv_resample
+        self.num_classes = num_classes
+        self.use_checkpoint = use_checkpoint
+        self.dtype = th.float16 if use_fp16 else th.float32
+        self.num_heads = num_heads
+        self.num_head_channels = num_head_channels
+        self.num_heads_upsample = num_heads_upsample
+        self.predict_codebook_ids = n_embed is not None
+
+        time_embed_dim = model_channels * 4
+        self.time_embed = nn.Sequential(
+            linear(model_channels, time_embed_dim),
+            nn.SiLU(),
+            linear(time_embed_dim, time_embed_dim),
+        )
+
+        if self.num_classes is not None:
+            if isinstance(self.num_classes, int):
+                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+            elif self.num_classes == "continuous":
+                print("setting up linear c_adm embedding layer")
+                self.label_emb = nn.Linear(1, time_embed_dim)
+            else:
+                raise ValueError()
+
+        self.input_blocks = nn.ModuleList(
+            [
+                TimestepEmbedSequential(
+                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
+                )
+            ]
+        )
+        self._feature_size = model_channels
+        input_block_chans = [model_channels]
+        ch = model_channels
+        ds = 1
+        for level, mult in enumerate(channel_mult):
+            for nr in range(self.num_res_blocks[level]):
+                layers = [
+                    ResBlock(
+                        ch,
+                        time_embed_dim,
+                        dropout,
+                        out_channels=mult * model_channels,
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                    )
+                ]
+                ch = mult * model_channels
+                if ds in attention_resolutions:
+                    if num_head_channels == -1:
+                        dim_head = ch // num_heads
+                    else:
+                        num_heads = ch // num_head_channels
+                        dim_head = num_head_channels
+                    if legacy:
+                        #num_heads = 1
+                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+                    if exists(disable_self_attentions):
+                        disabled_sa = disable_self_attentions[level]
+                    else:
+                        disabled_sa = False
+
+                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+                        layers.append(
+                            AttentionBlock(
+                                ch,
+                                use_checkpoint=use_checkpoint,
+                                num_heads=num_heads,
+                                num_head_channels=dim_head,
+                                use_new_attention_order=use_new_attention_order,
+                            ) if not use_spatial_transformer else SpatialTransformer(
+                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+                                use_checkpoint=use_checkpoint
+                            )
+                        )
+                self.input_blocks.append(TimestepEmbedSequential(*layers))
+                self._feature_size += ch
+                input_block_chans.append(ch)
+            if level != len(channel_mult) - 1:
+                out_ch = ch
+                self.input_blocks.append(
+                    TimestepEmbedSequential(
+                        ResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            down=True,
+                        )
+                        if resblock_updown
+                        else Downsample(
+                            ch, conv_resample, dims=dims, out_channels=out_ch
+                        )
+                    )
+                )
+                ch = out_ch
+                input_block_chans.append(ch)
+                ds *= 2
+                self._feature_size += ch
+
+        if num_head_channels == -1:
+            dim_head = ch // num_heads
+        else:
+            num_heads = ch // num_head_channels
+            dim_head = num_head_channels
+        if legacy:
+            #num_heads = 1
+            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+        self.middle_block = TimestepEmbedSequential(
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+            AttentionBlock(
+                ch,
+                use_checkpoint=use_checkpoint,
+                num_heads=num_heads,
+                num_head_channels=dim_head,
+                use_new_attention_order=use_new_attention_order,
+            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
+                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+                use_checkpoint=use_checkpoint
+            ),
+            ResBlock(
+                ch,
+                time_embed_dim,
+                dropout,
+                dims=dims,
+                use_checkpoint=use_checkpoint,
+                use_scale_shift_norm=use_scale_shift_norm,
+            ),
+        )
+        self._feature_size += ch
+
+        self.output_blocks = nn.ModuleList([])
+        for level, mult in list(enumerate(channel_mult))[::-1]:
+            for i in range(self.num_res_blocks[level] + 1):
+                ich = input_block_chans.pop()
+                layers = [
+                    ResBlock(
+                        ch + ich,
+                        time_embed_dim,
+                        dropout,
+                        out_channels=model_channels * mult,
+                        dims=dims,
+                        use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm,
+                    )
+                ]
+                ch = model_channels * mult
+                if ds in attention_resolutions:
+                    if num_head_channels == -1:
+                        dim_head = ch // num_heads
+                    else:
+                        num_heads = ch // num_head_channels
+                        dim_head = num_head_channels
+                    if legacy:
+                        #num_heads = 1
+                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+                    if exists(disable_self_attentions):
+                        disabled_sa = disable_self_attentions[level]
+                    else:
+                        disabled_sa = False
+
+                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
+                        layers.append(
+                            AttentionBlock(
+                                ch,
+                                use_checkpoint=use_checkpoint,
+                                num_heads=num_heads_upsample,
+                                num_head_channels=dim_head,
+                                use_new_attention_order=use_new_attention_order,
+                            ) if not use_spatial_transformer else SpatialTransformer(
+                                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+                                use_checkpoint=use_checkpoint
+                            )
+                        )
+                if level and i == self.num_res_blocks[level]:
+                    out_ch = ch
+                    layers.append(
+                        ResBlock(
+                            ch,
+                            time_embed_dim,
+                            dropout,
+                            out_channels=out_ch,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            up=True,
+                        )
+                        if resblock_updown
+                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                    )
+                    ds //= 2
+                self.output_blocks.append(TimestepEmbedSequential(*layers))
+                self._feature_size += ch
+
+        self.out = nn.Sequential(
+            normalization(ch),
+            nn.SiLU(),
+            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+        )
+        if self.predict_codebook_ids:
+            self.id_predictor = nn.Sequential(
+                normalization(ch),
+                conv_nd(dims, model_channels, n_embed, 1),
+                # nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
+            )
+
+    def convert_to_fp16(self):
+        """
+        Convert the torso of the model to float16.
+        """
+        self.input_blocks.apply(convert_module_to_f16)
+        self.middle_block.apply(convert_module_to_f16)
+        self.output_blocks.apply(convert_module_to_f16)
+
+    def convert_to_fp32(self):
+        """
+        Convert the torso of the model to float32.
+        """
+        self.input_blocks.apply(convert_module_to_f32)
+        self.middle_block.apply(convert_module_to_f32)
+        self.output_blocks.apply(convert_module_to_f32)
+
+    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
+        """
+        Apply the model to an input batch.
+        :param x: an [N x C x ...] Tensor of inputs.
+        :param timesteps: a 1-D batch of timesteps.
+        :param context: conditioning plugged in via crossattn
+        :param y: an [N] Tensor of labels, if class-conditional.
+        :return: an [N x C x ...] Tensor of outputs.
+        """
+        assert (y is not None) == (
+            self.num_classes is not None
+        ), "must specify y if and only if the model is class-conditional"
+        hs = []
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+        emb = self.time_embed(t_emb)
+
+        if self.num_classes is not None:
+            assert y.shape[0] == x.shape[0]
+            emb = emb + self.label_emb(y)
+
+        h = x.type(self.dtype)
+        for module in self.input_blocks:
+            h = module(h, emb, context)
+            hs.append(h)
+        h = self.middle_block(h, emb, context)
+        for module in self.output_blocks:
+            h = th.cat([h, hs.pop()], dim=1)
+            h = module(h, emb, context)
+        h = h.type(x.dtype)
+        if self.predict_codebook_ids:
+            return self.id_predictor(h)
+        else:
+            return self.out(h)
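+
+
+def _example_unet_forward():
+    # Illustrative usage sketch only: the hyper-parameters below are
+    # assumptions chosen to be small enough to run on CPU, not the settings
+    # of any released checkpoint.
+    model = UNetModel(
+        image_size=32, in_channels=4, model_channels=64, out_channels=4,
+        num_res_blocks=1, attention_resolutions=[4], channel_mult=[1, 2],
+        num_heads=4, use_spatial_transformer=True, transformer_depth=1,
+        context_dim=64,
+    )
+    x = th.randn(2, 4, 32, 32)                 # a batch of latents
+    t = th.randint(0, 1000, (2,))              # one diffusion timestep per sample
+    ctx = th.randn(2, 77, 64)                  # cross-attention conditioning
+    return model(x, timesteps=t, context=ctx)  # same spatial shape as x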
diff --git a/ldm/modules/diffusionmodules/upscaling.py b/ldm/modules/diffusionmodules/upscaling.py
new file mode 100644
index 0000000000000000000000000000000000000000..03816662098ce1ffac79bd939b892e867ab91988
--- /dev/null
+++ b/ldm/modules/diffusionmodules/upscaling.py
@@ -0,0 +1,81 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from functools import partial
+
+from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
+from ldm.util import default
+
+
+class AbstractLowScaleModel(nn.Module):
+    # for concatenating a downsampled image to the latent representation
+    def __init__(self, noise_schedule_config=None):
+        super(AbstractLowScaleModel, self).__init__()
+        if noise_schedule_config is not None:
+            self.register_schedule(**noise_schedule_config)
+
+    def register_schedule(self, beta_schedule="linear", timesteps=1000,
+                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+                                   cosine_s=cosine_s)
+        alphas = 1. - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.linear_start = linear_start
+        self.linear_end = linear_end
+        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+
+        self.register_buffer('betas', to_torch(betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+    def q_sample(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+    def forward(self, x):
+        return x, None
+
+    def decode(self, x):
+        return x
+
+
+class SimpleImageConcat(AbstractLowScaleModel):
+    # no noise level conditioning
+    def __init__(self):
+        super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
+        self.max_noise_level = 0
+
+    def forward(self, x):
+        # fix to constant noise level
+        return x, torch.zeros(x.shape[0], device=x.device).long()
+
+
+class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
+    def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
+        super().__init__(noise_schedule_config=noise_schedule_config)
+        self.max_noise_level = max_noise_level
+
+    def forward(self, x, noise_level=None):
+        if noise_level is None:
+            noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
+        else:
+            assert isinstance(noise_level, torch.Tensor)
+        z = self.q_sample(x, noise_level)
+        return z, noise_level
+
+
+
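+def _example_noise_augmentation():
+    # Illustrative sketch (the schedule endpoints and tensor shape are
+    # assumptions): a low-resolution conditioning image is noised to a random
+    # level via q_sample, and that level is returned as extra conditioning.
+    aug = ImageConcatWithNoiseAugmentation(
+        noise_schedule_config=dict(linear_start=1e-4, linear_end=2e-2),
+        max_noise_level=1000,
+    )
+    x_low = torch.randn(2, 3, 64, 64)
+    z, noise_level = aug(x_low)
+    return z.shape, noise_level   # torch.Size([2, 3, 64, 64]), LongTensor of shape [2]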
diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..637363dfe34799e70cfdbcd11445212df9d9ca1f
--- /dev/null
+++ b/ldm/modules/diffusionmodules/util.py
@@ -0,0 +1,270 @@
+# adapted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+
+import os
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import repeat
+
+from ldm.util import instantiate_from_config
+
+
+def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+    if schedule == "linear":
+        betas = (
+                torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+        )
+
+    elif schedule == "cosine":
+        timesteps = (
+                torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+        )
+        alphas = timesteps / (1 + cosine_s) * np.pi / 2
+        alphas = torch.cos(alphas).pow(2)
+        alphas = alphas / alphas[0]
+        betas = 1 - alphas[1:] / alphas[:-1]
+        betas = torch.clamp(betas, min=0, max=0.999)  # betas is a torch tensor here, so clamp rather than np.clip
+
+    elif schedule == "sqrt_linear":
+        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+    elif schedule == "sqrt":
+        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+    else:
+        raise ValueError(f"schedule '{schedule}' unknown.")
+    return betas.numpy()
+
+
+def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
+    if ddim_discr_method == 'uniform':
+        c = num_ddpm_timesteps // num_ddim_timesteps
+        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+    elif ddim_discr_method == 'quad':
+        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
+    else:
+        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
+
+    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
+    # add one so the final alpha values are correct (the ones used for the first scale-to-data step during sampling)
+    steps_out = ddim_timesteps + 1
+    if verbose:
+        print(f'Selected timesteps for ddim sampler: {steps_out}')
+    return steps_out
+
+
+def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
+    # select alphas for computing the variance schedule
+    alphas = alphacums[ddim_timesteps]
+    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
+
+    # according to the formula provided in https://arxiv.org/abs/2010.02502
+    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
+    if verbose:
+        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
+        print(f'For the chosen value of eta, which is {eta}, '
+              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
+    return sigmas, alphas, alphas_prev
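+
+
+def _example_ddim_schedule():
+    # Illustrative sketch of how these helpers fit together for a 50-step DDIM
+    # schedule (the step count and eta below are assumptions, not defaults).
+    betas = make_beta_schedule("linear", 1000)
+    alphas_cumprod = np.cumprod(1. - betas, axis=0)
+    ddim_steps = make_ddim_timesteps("uniform", 50, 1000, verbose=False)
+    sigmas, alphas, alphas_prev = make_ddim_sampling_parameters(
+        alphas_cumprod, ddim_steps, eta=0.0, verbose=False)
+    return sigmas, alphas, alphas_prev   # three arrays of length 50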
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+    """
+    Create a beta schedule that discretizes the given alpha_t_bar function,
+    which defines the cumulative product of (1-beta) over time from t = [0,1].
+    :param num_diffusion_timesteps: the number of betas to produce.
+    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+                      produces the cumulative product of (1-beta) up to that
+                      part of the diffusion process.
+    :param max_beta: the maximum beta to use; use values lower than 1 to
+                     prevent singularities.
+    """
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    return np.array(betas)
+
+
+def extract_into_tensor(a, t, x_shape):
+    b, *_ = t.shape
+    out = a.gather(-1, t)
+    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def checkpoint(func, inputs, params, flag):
+    """
+    Evaluate a function without caching intermediate activations, allowing for
+    reduced memory at the expense of extra compute in the backward pass.
+    :param func: the function to evaluate.
+    :param inputs: the argument sequence to pass to `func`.
+    :param params: a sequence of parameters `func` depends on but does not
+                   explicitly take as arguments.
+    :param flag: if False, disable gradient checkpointing.
+    """
+    if flag:
+        args = tuple(inputs) + tuple(params)
+        return CheckpointFunction.apply(func, len(inputs), *args)
+    else:
+        return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, run_function, length, *args):
+        ctx.run_function = run_function
+        ctx.input_tensors = list(args[:length])
+        ctx.input_params = list(args[length:])
+        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
+                                   "dtype": torch.get_autocast_gpu_dtype(),
+                                   "cache_enabled": torch.is_autocast_cache_enabled()}
+        with torch.no_grad():
+            output_tensors = ctx.run_function(*ctx.input_tensors)
+        return output_tensors
+
+    @staticmethod
+    def backward(ctx, *output_grads):
+        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+        with torch.enable_grad(), \
+                torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+            # Fixes a bug where the first op in run_function modifies the
+            # Tensor storage in place, which is not allowed for detach()'d
+            # Tensors.
+            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+            output_tensors = ctx.run_function(*shallow_copies)
+        input_grads = torch.autograd.grad(
+            output_tensors,
+            ctx.input_tensors + ctx.input_params,
+            output_grads,
+            allow_unused=True,
+        )
+        del ctx.input_tensors
+        del ctx.input_params
+        del output_tensors
+        return (None, None) + input_grads
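+
+
+def _example_checkpoint():
+    # Illustrative sketch (the layer and shapes are arbitrary assumptions):
+    # wrap a block so its intermediate activations are recomputed during the
+    # backward pass instead of being stored.
+    layer = nn.Linear(16, 16)
+    x = torch.randn(4, 16, requires_grad=True)
+
+    def block(inp):
+        return torch.nn.functional.silu(layer(inp))
+
+    y = checkpoint(block, (x,), tuple(layer.parameters()), True)
+    return y.sum()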
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+    """
+    Create sinusoidal timestep embeddings.
+    :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                      These may be fractional.
+    :param dim: the dimension of the output.
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
+    """
+    if not repeat_only:
+        half = dim // 2
+        freqs = torch.exp(
+            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+        ).to(device=timesteps.device)
+        args = timesteps[:, None].float() * freqs[None]
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        if dim % 2:
+            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+    else:
+        embedding = repeat(timesteps, 'b -> b d', d=dim)
+    return embedding
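+
+
+def _example_timestep_embedding():
+    # Illustrative sketch: three timesteps embedded into 320-dim sinusoidal
+    # vectors (the embedding width is an arbitrary assumption).
+    t = torch.tensor([0, 250, 999])
+    return timestep_embedding(t, dim=320)   # shape [3, 320]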
+
+
+def zero_module(module):
+    """
+    Zero out the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+def scale_module(module, scale):
+    """
+    Scale the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().mul_(scale)
+    return module
+
+
+def mean_flat(tensor):
+    """
+    Take the mean over all non-batch dimensions.
+    """
+    return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+    """
+    Make a standard normalization layer.
+    :param channels: number of input channels.
+    :return: an nn.Module for normalization.
+    """
+    return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+    def forward(self, x):
+        return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+    def forward(self, x):
+        return super().forward(x.float()).type(x.dtype)
+
+def conv_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D convolution module.
+    """
+    if dims == 1:
+        return nn.Conv1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.Conv2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.Conv3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+    """
+    Create a linear module.
+    """
+    return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+    """
+    Create a 1D, 2D, or 3D average pooling module.
+    """
+    if dims == 1:
+        return nn.AvgPool1d(*args, **kwargs)
+    elif dims == 2:
+        return nn.AvgPool2d(*args, **kwargs)
+    elif dims == 3:
+        return nn.AvgPool3d(*args, **kwargs)
+    raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class HybridConditioner(nn.Module):
+
+    def __init__(self, c_concat_config, c_crossattn_config):
+        super().__init__()
+        self.concat_conditioner = instantiate_from_config(c_concat_config)
+        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+    def forward(self, c_concat, c_crossattn):
+        c_concat = self.concat_conditioner(c_concat)
+        c_crossattn = self.crossattn_conditioner(c_crossattn)
+        return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
+
+
+def noise_like(shape, device, repeat=False):
+    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+    noise = lambda: torch.randn(shape, device=device)
+    return repeat_noise() if repeat else noise()
\ No newline at end of file
diff --git a/ldm/modules/distributions/__init__.py b/ldm/modules/distributions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc b/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bb47d679a7575a653d7cc4f9a6a0a2d1a0fac7af
Binary files /dev/null and b/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc differ
diff --git a/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc b/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4743d3dc8d7ab235423d655854fefab1c1cdecb2
Binary files /dev/null and b/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc differ
diff --git a/ldm/modules/distributions/distributions.py b/ldm/modules/distributions/distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9
--- /dev/null
+++ b/ldm/modules/distributions/distributions.py
@@ -0,0 +1,92 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+    def sample(self):
+        raise NotImplementedError()
+
+    def mode(self):
+        raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+    def __init__(self, value):
+        self.value = value
+
+    def sample(self):
+        return self.value
+
+    def mode(self):
+        return self.value
+
+
+class DiagonalGaussianDistribution(object):
+    def __init__(self, parameters, deterministic=False):
+        self.parameters = parameters
+        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+        self.deterministic = deterministic
+        self.std = torch.exp(0.5 * self.logvar)
+        self.var = torch.exp(self.logvar)
+        if self.deterministic:
+            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+    def sample(self):
+        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+        return x
+
+    def kl(self, other=None):
+        if self.deterministic:
+            return torch.Tensor([0.])
+        else:
+            if other is None:
+                return 0.5 * torch.sum(torch.pow(self.mean, 2)
+                                       + self.var - 1.0 - self.logvar,
+                                       dim=[1, 2, 3])
+            else:
+                return 0.5 * torch.sum(
+                    torch.pow(self.mean - other.mean, 2) / other.var
+                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
+                    dim=[1, 2, 3])
+
+    def nll(self, sample, dims=[1,2,3]):
+        if self.deterministic:
+            return torch.Tensor([0.])
+        logtwopi = np.log(2.0 * np.pi)
+        return 0.5 * torch.sum(
+            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+            dim=dims)
+
+    def mode(self):
+        return self.mean
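+
+
+def _example_diagonal_gaussian():
+    # Illustrative sketch: an encoder is assumed to emit 2*C channels (mean and
+    # log-variance stacked along dim 1); the shapes below are arbitrary.
+    moments = torch.randn(2, 8, 16, 16)   # 2*C with C = 4
+    posterior = DiagonalGaussianDistribution(moments)
+    z = posterior.sample()                # [2, 4, 16, 16]
+    kl = posterior.kl()                   # KL to a standard normal, shape [2]
+    return z, kl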
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+    """
+    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+    Compute the KL divergence between two gaussians.
+    Shapes are automatically broadcasted, so batches can be compared to
+    scalars, among other use cases.
+    """
+    tensor = None
+    for obj in (mean1, logvar1, mean2, logvar2):
+        if isinstance(obj, torch.Tensor):
+            tensor = obj
+            break
+    assert tensor is not None, "at least one argument must be a Tensor"
+
+    # Force variances to be Tensors. Broadcasting helps convert scalars to
+    # Tensors, but it does not work for torch.exp().
+    logvar1, logvar2 = [
+        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+        for x in (logvar1, logvar2)
+    ]
+
+    return 0.5 * (
+        -1.0
+        + logvar2
+        - logvar1
+        + torch.exp(logvar1 - logvar2)
+        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+    )
diff --git a/ldm/modules/ema.py b/ldm/modules/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..bded25019b9bcbcd0260f0b8185f8c7859ca58c4
--- /dev/null
+++ b/ldm/modules/ema.py
@@ -0,0 +1,80 @@
+import torch
+from torch import nn
+
+
+class LitEma(nn.Module):
+    def __init__(self, model, decay=0.9999, use_num_upates=True):
+        super().__init__()
+        if decay < 0.0 or decay > 1.0:
+            raise ValueError('Decay must be between 0 and 1')
+
+        self.m_name2s_name = {}
+        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
+        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
+        else torch.tensor(-1, dtype=torch.int))
+
+        for name, p in model.named_parameters():
+            if p.requires_grad:
+                # remove as '.'-character is not allowed in buffers
+                s_name = name.replace('.', '')
+                self.m_name2s_name.update({name: s_name})
+                self.register_buffer(s_name, p.clone().detach().data)
+
+        self.collected_params = []
+
+    def reset_num_updates(self):
+        del self.num_updates
+        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
+
+    def forward(self, model):
+        decay = self.decay
+
+        if self.num_updates >= 0:
+            self.num_updates += 1
+            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
+
+        one_minus_decay = 1.0 - decay
+
+        with torch.no_grad():
+            m_param = dict(model.named_parameters())
+            shadow_params = dict(self.named_buffers())
+
+            for key in m_param:
+                if m_param[key].requires_grad:
+                    sname = self.m_name2s_name[key]
+                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
+                else:
+                    assert key not in self.m_name2s_name
+
+    def copy_to(self, model):
+        m_param = dict(model.named_parameters())
+        shadow_params = dict(self.named_buffers())
+        for key in m_param:
+            if m_param[key].requires_grad:
+                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+            else:
+                assert key not in self.m_name2s_name
+
+    def store(self, parameters):
+        """
+        Save the current parameters for restoring later.
+        Args:
+          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+            temporarily stored.
+        """
+        self.collected_params = [param.clone() for param in parameters]
+
+    def restore(self, parameters):
+        """
+        Restore the parameters stored with the `store` method.
+        Useful to validate the model with EMA parameters without affecting the
+        original optimization process. Store the parameters before the
+        `copy_to` method. After validation (or model saving), use this to
+        restore the former parameters.
+        Args:
+          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+            updated with the stored parameters.
+        """
+        for c_param, param in zip(self.collected_params, parameters):
+            param.data.copy_(c_param.data)
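+
+
+def _example_ema_usage():
+    # Illustrative sketch of the intended call pattern (the Linear model is a
+    # stand-in): update the shadow weights after each optimizer step, then swap
+    # them in for validation and restore the training weights afterwards.
+    model = nn.Linear(4, 4)
+    ema = LitEma(model)
+    ema(model)                       # call once per optimizer step
+    ema.store(model.parameters())    # keep the training weights
+    ema.copy_to(model)               # validate with the EMA weights
+    ema.restore(model.parameters())  # back to the training weights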
diff --git a/ldm/modules/encoders/__init__.py b/ldm/modules/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc b/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..07d6cf409a58232813e96e2733c63f5896a38372
Binary files /dev/null and b/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc differ
diff --git a/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc b/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4784df72bfd743ce4848894a03c5a4c9b2f6987e
Binary files /dev/null and b/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc differ
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..4edd5496b9e668ea72a5be39db9cca94b6a42f9b
--- /dev/null
+++ b/ldm/modules/encoders/modules.py
@@ -0,0 +1,213 @@
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
+
+import open_clip
+from ldm.util import default, count_params
+
+
+class AbstractEncoder(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def encode(self, *args, **kwargs):
+        raise NotImplementedError
+
+
+class IdentityEncoder(AbstractEncoder):
+
+    def encode(self, x):
+        return x
+
+
+class ClassEmbedder(nn.Module):
+    def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
+        super().__init__()
+        self.key = key
+        self.embedding = nn.Embedding(n_classes, embed_dim)
+        self.n_classes = n_classes
+        self.ucg_rate = ucg_rate
+
+    def forward(self, batch, key=None, disable_dropout=False):
+        if key is None:
+            key = self.key
+        # this is for use in crossattn
+        c = batch[key][:, None]
+        if self.ucg_rate > 0. and not disable_dropout:
+            mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
+            c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
+            c = c.long()
+        c = self.embedding(c)
+        return c
+
+    def get_unconditional_conditioning(self, bs, device="cuda"):
+        uc_class = self.n_classes - 1  # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
+        uc = torch.ones((bs,), device=device) * uc_class
+        uc = {self.key: uc}
+        return uc
+
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+class FrozenT5Embedder(AbstractEncoder):
+    """Uses the T5 transformer encoder for text"""
+    def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+        super().__init__()
+        self.tokenizer = T5Tokenizer.from_pretrained(version)
+        self.transformer = T5EncoderModel.from_pretrained(version)
+        self.device = device
+        self.max_length = max_length   # TODO: typical value?
+        if freeze:
+            self.freeze()
+
+    def freeze(self):
+        self.transformer = self.transformer.eval()
+        #self.train = disabled_train
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, text):
+        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+        tokens = batch_encoding["input_ids"].to(self.device)
+        outputs = self.transformer(input_ids=tokens)
+
+        z = outputs.last_hidden_state
+        return z
+
+    def encode(self, text):
+        return self(text)
+
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+    """Uses the CLIP transformer encoder for text (from huggingface)"""
+    LAYERS = [
+        "last",
+        "pooled",
+        "hidden"
+    ]
+    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
+                 freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32
+        super().__init__()
+        assert layer in self.LAYERS
+        self.tokenizer = CLIPTokenizer.from_pretrained(version)
+        self.transformer = CLIPTextModel.from_pretrained(version)
+        self.device = device
+        self.max_length = max_length
+        if freeze:
+            self.freeze()
+        self.layer = layer
+        self.layer_idx = layer_idx
+        if layer == "hidden":
+            assert layer_idx is not None
+            assert 0 <= abs(layer_idx) <= 12
+
+    def freeze(self):
+        self.transformer = self.transformer.eval()
+        #self.train = disabled_train
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, text):
+        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+        tokens = batch_encoding["input_ids"].to(self.device)
+        outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
+        if self.layer == "last":
+            z = outputs.last_hidden_state
+        elif self.layer == "pooled":
+            z = outputs.pooler_output[:, None, :]
+        else:
+            z = outputs.hidden_states[self.layer_idx]
+        return z
+
+    def encode(self, text):
+        return self(text)
+
+
+class FrozenOpenCLIPEmbedder(AbstractEncoder):
+    """
+    Uses the OpenCLIP transformer encoder for text
+    """
+    LAYERS = [
+        #"pooled",
+        "last",
+        "penultimate"
+    ]
+    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
+                 freeze=True, layer="last"):
+        super().__init__()
+        assert layer in self.LAYERS
+        model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
+        del model.visual
+        self.model = model
+
+        self.device = device
+        self.max_length = max_length
+        if freeze:
+            self.freeze()
+        self.layer = layer
+        if self.layer == "last":
+            self.layer_idx = 0
+        elif self.layer == "penultimate":
+            self.layer_idx = 1
+        else:
+            raise NotImplementedError()
+
+    def freeze(self):
+        self.model = self.model.eval()
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def forward(self, text):
+        tokens = open_clip.tokenize(text)
+        z = self.encode_with_transformer(tokens.to(self.device))
+        return z
+
+    def encode_with_transformer(self, text):
+        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
+        x = x + self.model.positional_embedding
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.model.ln_final(x)
+        return x
+
+    def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
+        for i, r in enumerate(self.model.transformer.resblocks):
+            if i == len(self.model.transformer.resblocks) - self.layer_idx:
+                break
+            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(r, x, attn_mask)
+            else:
+                x = r(x, attn_mask=attn_mask)
+        return x
+
+    def encode(self, text):
+        return self(text)
+
+
+class FrozenCLIPT5Encoder(AbstractEncoder):
+    def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
+                 clip_max_length=77, t5_max_length=77):
+        super().__init__()
+        self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
+        self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
+        print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
+              f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
+
+    def encode(self, text):
+        return self(text)
+
+    def forward(self, text):
+        clip_z = self.clip_encoder.encode(text)
+        t5_z = self.t5_encoder.encode(text)
+        return [clip_z, t5_z]
+
+
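+def _example_text_conditioning():
+    # Illustrative sketch (assumes the HuggingFace weights can be downloaded or
+    # are cached locally; the prompt is arbitrary): encode a batch of prompts
+    # into the [batch, 77, width] context tensor used for cross-attention.
+    encoder = FrozenCLIPEmbedder(device="cpu")
+    ctx = encoder.encode(["a photograph of an astronaut riding a horse"])
+    return ctx.shape   # torch.Size([1, 77, 768]) for clip-vit-large-patch14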
diff --git a/ldm/modules/image_degradation/__init__.py b/ldm/modules/image_degradation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7836cada81f90ded99c58d5942eea4c3477f58fc
--- /dev/null
+++ b/ldm/modules/image_degradation/__init__.py
@@ -0,0 +1,2 @@
+from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
+from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
diff --git a/ldm/modules/image_degradation/bsrgan.py b/ldm/modules/image_degradation/bsrgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..32ef56169978e550090261cddbcf5eb611a6173b
--- /dev/null
+++ b/ldm/modules/image_degradation/bsrgan.py
@@ -0,0 +1,730 @@
+# -*- coding: utf-8 -*-
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+import numpy as np
+import cv2
+import torch
+
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+
+import ldm.modules.image_degradation.utils_image as util
+
+
+def modcrop_np(img, sf):
+    '''
+    Args:
+        img: numpy image, WxH or WxHxC
+        sf: scale factor
+    Return:
+        cropped image
+    '''
+    w, h = img.shape[:2]
+    im = np.copy(img)
+    return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+    k_size = k.shape[0]
+    # Calculate the big kernel's size
+    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+    # Loop over the small kernel to fill the big one
+    for r in range(k_size):
+        for c in range(k_size):
+            big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+    # Crop the edges of the big kernel to ignore very small values and speed up the SR run time
+    crop = k_size // 2
+    cropped_big_k = big_k[crop:-crop, crop:-crop]
+    # Normalize to 1
+    return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+    """ generate an anisotropic Gaussian kernel
+    Args:
+        ksize : e.g., 15, kernel size
+        theta : [0,  pi], rotation angle range
+        l1    : [0.1,50], scaling of eigenvalues
+        l2    : [0.1,l1], scaling of eigenvalues
+        If l1 = l2, will get an isotropic Gaussian kernel.
+    Returns:
+        k     : kernel
+    """
+
+    v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+    D = np.array([[l1, 0], [0, l2]])
+    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+    k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+    return k
+
+
+def gm_blur_kernel(mean, cov, size=15):
+    center = size / 2.0 + 0.5
+    k = np.zeros([size, size])
+    for y in range(size):
+        for x in range(size):
+            cy = y - center + 1
+            cx = x - center + 1
+            k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+    k = k / np.sum(k)
+    return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+    """shift pixel for super-resolution with different scale factors
+    Args:
+        x: WxHxC or WxH
+        sf: scale factor
+        upper_left: shift direction
+    """
+    h, w = x.shape[:2]
+    shift = (sf - 1) * 0.5
+    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+    if upper_left:
+        x1 = xv + shift
+        y1 = yv + shift
+    else:
+        x1 = xv - shift
+        y1 = yv - shift
+
+    x1 = np.clip(x1, 0, w - 1)
+    y1 = np.clip(y1, 0, h - 1)
+
+    if x.ndim == 2:
+        x = interp2d(xv, yv, x)(x1, y1)
+    if x.ndim == 3:
+        for i in range(x.shape[-1]):
+            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+    return x
+
+
+def blur(x, k):
+    '''
+    x: image, NxcxHxW
+    k: kernel, Nx1xhxw
+    '''
+    n, c = x.shape[:2]
+    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+    x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+    k = k.repeat(1, c, 1, 1)
+    k = k.view(-1, 1, k.shape[2], k.shape[3])
+    x = x.view(1, -1, x.shape[2], x.shape[3])
+    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+    x = x.view(n, c, x.shape[2], x.shape[3])
+
+    return x
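+
+
+def _example_blur():
+    # Illustrative sketch (shapes are assumptions): blur a batch of two RGB
+    # images with a separate kernel per image via the grouped convolution above.
+    x = torch.rand(2, 3, 32, 32)
+    k = torch.rand(2, 1, 7, 7)
+    k = k / k.sum(dim=(-2, -1), keepdim=True)   # normalise each kernel
+    return blur(x, k).shape                     # torch.Size([2, 3, 32, 32])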
+
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+    """"
+    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+    # Kai Zhang
+    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
+    # max_var = 2.5 * sf
+    """
+    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+    theta = np.random.rand() * np.pi  # random theta
+    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+    # Set COV matrix using Lambdas and Theta
+    LAMBDA = np.diag([lambda_1, lambda_2])
+    Q = np.array([[np.cos(theta), -np.sin(theta)],
+                  [np.sin(theta), np.cos(theta)]])
+    SIGMA = Q @ LAMBDA @ Q.T
+    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+    # Set expectation position (shifting kernel for aligned image)
+    MU = k_size // 2 - 0.5 * (scale_factor - 1)  # - 0.5 * (scale_factor - k_size % 2)
+    MU = MU[None, None, :, None]
+
+    # Create meshgrid for Gaussian
+    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+    Z = np.stack([X, Y], 2)[:, :, :, None]
+
+    # Calculate the Gaussian for every pixel of the kernel
+    ZZ = Z - MU
+    ZZ_t = ZZ.transpose(0, 1, 3, 2)
+    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+    # shift the kernel so it will be centered
+    # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+    # Normalize the kernel and return
+    # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+    kernel = raw_kernel / np.sum(raw_kernel)
+    return kernel
+
+
+def fspecial_gaussian(hsize, sigma):
+    hsize = [hsize, hsize]
+    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+    std = sigma
+    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+    arg = -(x * x + y * y) / (2 * std * std)
+    h = np.exp(arg)
+    h[h < np.finfo(float).eps * h.max()] = 0
+    sumh = h.sum()
+    if sumh != 0:
+        h = h / sumh
+    return h
+
+
+def fspecial_laplacian(alpha):
+    alpha = max([0, min([alpha, 1])])
+    h1 = alpha / (alpha + 1)
+    h2 = (1 - alpha) / (alpha + 1)
+    h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+    h = np.array(h)
+    return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+    '''
+    python code from:
+    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+    '''
+    if filter_type == 'gaussian':
+        return fspecial_gaussian(*args, **kwargs)
+    if filter_type == 'laplacian':
+        return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+    '''
+    Args:
+        x: HxWxC image, [0, 1]
+        sf: down-scale factor
+    Return:
+        bicubicly downsampled LR image
+    '''
+    x = util.imresize_np(x, scale=1 / sf)
+    return x
+
+
+def srmd_degradation(x, k, sf=3):
+    ''' blur + bicubic downsampling
+    Args:
+        x: HxWxC image, [0, 1]
+        k: hxw, double
+        sf: down-scale factor
+    Return:
+        downsampled LR image
+    Reference:
+        @inproceedings{zhang2018learning,
+          title={Learning a single convolutional super-resolution network for multiple degradations},
+          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+          pages={3262--3271},
+          year={2018}
+        }
+    '''
+    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
+    x = bicubic_degradation(x, sf=sf)
+    return x
+
+
+def dpsr_degradation(x, k, sf=3):
+    ''' bicubic downsampling + blur
+    Args:
+        x: HxWxC image, [0, 1]
+        k: hxw, double
+        sf: down-scale factor
+    Return:
+        downsampled LR image
+    Reference:
+        @inproceedings{zhang2019deep,
+          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+          pages={1671--1681},
+          year={2019}
+        }
+    '''
+    x = bicubic_degradation(x, sf=sf)
+    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+    return x
+
+
+def classical_degradation(x, k, sf=3):
+    ''' blur + downsampling
+    Args:
+        x: HxWxC image, [0, 1]/[0, 255]
+        k: hxw, double
+        sf: down-scale factor
+    Return:
+        downsampled LR image
+    '''
+    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+    st = 0
+    return x[st::sf, st::sf, ...]
+
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+    """USM sharpening. borrowed from real-ESRGAN
+    Input image: I; Blurry image: B.
+    1. K = I + weight * (I - B)
+    2. Mask = 1 if abs(I - B) > threshold, else: 0
+    3. Blur mask:
+    4. Out = Mask * K + (1 - Mask) * I
+    Args:
+        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+        weight (float): Sharp weight. Default: 0.5.
+        radius (float): Kernel size of Gaussian blur. Default: 50.
+        threshold (int): Residual threshold on a 0-255 scale for the sharpening mask. Default: 10.
+    """
+    if radius % 2 == 0:
+        radius += 1
+    blur = cv2.GaussianBlur(img, (radius, radius), 0)
+    residual = img - blur
+    mask = np.abs(residual) * 255 > threshold
+    mask = mask.astype('float32')
+    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+    K = img + weight * residual
+    K = np.clip(K, 0, 1)
+    return soft_mask * K + (1 - soft_mask) * img
+
+
+def add_blur(img, sf=4):
+    wd2 = 4.0 + sf
+    wd = 2.0 + 0.2 * sf
+    if random.random() < 0.5:
+        l1 = wd2 * random.random()
+        l2 = wd2 * random.random()
+        k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+    else:
+        k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
+    img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+    return img
+
+
+def add_resize(img, sf=4):
+    rnum = np.random.rand()
+    if rnum > 0.8:  # up
+        sf1 = random.uniform(1, 2)
+    elif rnum < 0.7:  # down
+        sf1 = random.uniform(0.5 / sf, 1)
+    else:
+        sf1 = 1.0
+    img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))  # 1/2/3 = cv2 INTER_LINEAR / INTER_CUBIC / INTER_AREA
+    img = np.clip(img, 0.0, 1.0)
+
+    return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+#     noise_level = random.randint(noise_level1, noise_level2)
+#     rnum = np.random.rand()
+#     if rnum > 0.6:  # add color Gaussian noise
+#         img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+#     elif rnum < 0.4:  # add grayscale Gaussian noise
+#         img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+#     else:  # add  noise
+#         L = noise_level2 / 255.
+#         D = np.diag(np.random.rand(3))
+#         U = orth(np.random.rand(3, 3))
+#         conv = np.dot(np.dot(np.transpose(U), D), U)
+#         img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+#     img = np.clip(img, 0.0, 1.0)
+#     return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+    noise_level = random.randint(noise_level1, noise_level2)
+    rnum = np.random.rand()
+    if rnum > 0.6:  # add color Gaussian noise
+        img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+    elif rnum < 0.4:  # add grayscale Gaussian noise
+        img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+    else:  # add noise with a random channel covariance
+        L = noise_level2 / 255.
+        D = np.diag(np.random.rand(3))
+        U = orth(np.random.rand(3, 3))
+        conv = np.dot(np.dot(np.transpose(U), D), U)
+        img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+    img = np.clip(img, 0.0, 1.0)
+    return img
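+
+
+def _example_add_gaussian_noise():
+    # Illustrative sketch (the input image is random): noise levels are drawn
+    # in [2, 25] on a 0-255 scale and the result is clipped back to [0, 1].
+    img = np.random.rand(64, 64, 3).astype(np.float32)
+    noisy = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+    return float(noisy.min()), float(noisy.max())   # both within [0, 1]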
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+    noise_level = random.randint(noise_level1, noise_level2)
+    img = np.clip(img, 0.0, 1.0)
+    rnum = random.random()
+    if rnum > 0.6:
+        img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+    elif rnum < 0.4:
+        img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+    else:
+        L = noise_level2 / 255.
+        D = np.diag(np.random.rand(3))
+        U = orth(np.random.rand(3, 3))
+        conv = np.dot(np.dot(np.transpose(U), D), U)
+        img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+    img = np.clip(img, 0.0, 1.0)
+    return img
+
+
+def add_Poisson_noise(img):
+    img = np.clip((img * 255.0).round(), 0, 255) / 255.
+    vals = 10 ** (2 * random.random() + 2.0)  # [2, 4]
+    if random.random() < 0.5:
+        img = np.random.poisson(img * vals).astype(np.float32) / vals
+    else:
+        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+        noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+        img += noise_gray[:, :, np.newaxis]
+    img = np.clip(img, 0.0, 1.0)
+    return img
+
+
+def add_JPEG_noise(img):
+    quality_factor = random.randint(30, 95)
+    img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+    result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+    img = cv2.imdecode(encimg, 1)
+    img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+    return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+    h, w = lq.shape[:2]
+    rnd_h = random.randint(0, h - lq_patchsize)
+    rnd_w = random.randint(0, w - lq_patchsize)
+    lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+    rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+    hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+    return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+    """
+    This is the degradation model of BSRGAN from the paper
+    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+    ----------
+    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
+    sf: scale factor
+    isp_model: camera ISP model
+    Returns
+    -------
+    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+    """
+    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+    sf_ori = sf
+
+    h1, w1 = img.shape[:2]
+    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop: make height/width multiples of sf
+    h, w = img.shape[:2]
+
+    if h < lq_patchsize * sf or w < lq_patchsize * sf:
+        raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+    hq = img.copy()
+
+    if sf == 4 and random.random() < scale2_prob:  # downsample1
+        if np.random.rand() < 0.5:
+            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+                             interpolation=random.choice([1, 2, 3]))
+        else:
+            img = util.imresize_np(img, 1 / 2, True)
+        img = np.clip(img, 0.0, 1.0)
+        sf = 2
+
+    shuffle_order = random.sample(range(7), 7)
+    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+    if idx1 > idx2:  # keep downsample3 last
+        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+    for i in shuffle_order:
+
+        if i == 0:
+            img = add_blur(img, sf=sf)
+
+        elif i == 1:
+            img = add_blur(img, sf=sf)
+
+        elif i == 2:
+            a, b = img.shape[1], img.shape[0]
+            # downsample2
+            if random.random() < 0.75:
+                sf1 = random.uniform(1, 2 * sf)
+                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+                                 interpolation=random.choice([1, 2, 3]))
+            else:
+                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+                k_shifted = shift_pixel(k, sf)
+                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
+                img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+                img = img[0::sf, 0::sf, ...]  # nearest downsampling
+            img = np.clip(img, 0.0, 1.0)
+
+        elif i == 3:
+            # downsample3
+            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+            img = np.clip(img, 0.0, 1.0)
+
+        elif i == 4:
+            # add Gaussian noise
+            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+
+        elif i == 5:
+            # add JPEG noise
+            if random.random() < jpeg_prob:
+                img = add_JPEG_noise(img)
+
+        elif i == 6:
+            # add processed camera sensor noise
+            if random.random() < isp_prob and isp_model is not None:
+                with torch.no_grad():
+                    img, hq = isp_model.forward(img.copy(), hq)
+
+    # add final JPEG compression noise
+    img = add_JPEG_noise(img)
+
+    # random crop
+    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+    return img, hq
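+
+# Usage sketch (illustrative only; the file name below is a placeholder): given an RGB
+# image in [0, 1] whose sides are at least lq_patchsize * sf, this returns a paired
+# low-/high-quality patch.
+#
+#   hq = util.uint2single(util.imread_uint('some_image.png', 3))
+#   lq_patch, hq_patch = degradation_bsrgan(hq, sf=4, lq_patchsize=72)
+#   # lq_patch: 72x72x3, hq_patch: 288x288x3, both in [0, 1]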
+
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None):
+    """
+    This is the degradation model of BSRGAN from the paper
+    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+    ----------
+    sf: scale factor
+    isp_model: camera ISP model
+    Returns
+    -------
+    example: dict with key "image" holding the degraded uint8 image (no random crop, no hq returned)
+    """
+    image = util.uint2single(image)
+    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+    sf_ori = sf
+
+    h1, w1 = image.shape[:2]
+    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop
+    h, w = image.shape[:2]
+
+    hq = image.copy()
+
+    if sf == 4 and random.random() < scale2_prob:  # downsample1
+        if np.random.rand() < 0.5:
+            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+                               interpolation=random.choice([1, 2, 3]))
+        else:
+            image = util.imresize_np(image, 1 / 2, True)
+        image = np.clip(image, 0.0, 1.0)
+        sf = 2
+
+    shuffle_order = random.sample(range(7), 7)
+    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+    if idx1 > idx2:  # keep downsample3 last
+        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+    for i in shuffle_order:
+
+        if i == 0:
+            image = add_blur(image, sf=sf)
+
+        elif i == 1:
+            image = add_blur(image, sf=sf)
+
+        elif i == 2:
+            a, b = image.shape[1], image.shape[0]
+            # downsample2
+            if random.random() < 0.75:
+                sf1 = random.uniform(1, 2 * sf)
+                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+                                   interpolation=random.choice([1, 2, 3]))
+            else:
+                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+                k_shifted = shift_pixel(k, sf)
+                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
+                image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+                image = image[0::sf, 0::sf, ...]  # nearest downsampling
+            image = np.clip(image, 0.0, 1.0)
+
+        elif i == 3:
+            # downsample3
+            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+            image = np.clip(image, 0.0, 1.0)
+
+        elif i == 4:
+            # add Gaussian noise
+            image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
+
+        elif i == 5:
+            # add JPEG noise
+            if random.random() < jpeg_prob:
+                image = add_JPEG_noise(image)
+
+        # elif i == 6:
+        #     # add processed camera sensor noise
+        #     if random.random() < isp_prob and isp_model is not None:
+        #         with torch.no_grad():
+        #             img, hq = isp_model.forward(img.copy(), hq)
+
+    # add final JPEG compression noise
+    image = add_JPEG_noise(image)
+    image = util.single2uint(image)
+    example = {"image":image}
+    return example
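+
+# Note / usage sketch (illustrative; the file name is a placeholder): unlike
+# degradation_bsrgan, this variant takes a uint8 image, handles the uint/single
+# conversions internally, and returns a dict rather than a tuple:
+#
+#   lq_uint8 = degradation_bsrgan_variant(util.imread_uint('some_image.png', 3), sf=4)["image"]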
+
+
+# TODO in case there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
+def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
+    """
+    This is an extended degradation model by combining
+    the degradation models of BSRGAN and Real-ESRGAN
+    ----------
+    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
+    sf: scale factor
+    shuffle_prob: probability of fully shuffling the degradation order
+    use_sharp: whether to sharpen the image before degradation
+    Returns
+    -------
+    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+    """
+
+    h1, w1 = img.shape[:2]
+    img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]  # mod crop
+    h, w = img.shape[:2]
+
+    if h < lq_patchsize * sf or w < lq_patchsize * sf:
+        raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+    if use_sharp:
+        img = add_sharpening(img)
+    hq = img.copy()
+
+    if random.random() < shuffle_prob:
+        shuffle_order = random.sample(range(13), 13)
+    else:
+        shuffle_order = list(range(13))
+        # local shuffle for noise, JPEG is always the last one
+        shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
+        shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
+
+    poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
+
+    for i in shuffle_order:
+        if i == 0:
+            img = add_blur(img, sf=sf)
+        elif i == 1:
+            img = add_resize(img, sf=sf)
+        elif i == 2:
+            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+        elif i == 3:
+            if random.random() < poisson_prob:
+                img = add_Poisson_noise(img)
+        elif i == 4:
+            if random.random() < speckle_prob:
+                img = add_speckle_noise(img)
+        elif i == 5:
+            if random.random() < isp_prob and isp_model is not None:
+                with torch.no_grad():
+                    img, hq = isp_model.forward(img.copy(), hq)
+        elif i == 6:
+            img = add_JPEG_noise(img)
+        elif i == 7:
+            img = add_blur(img, sf=sf)
+        elif i == 8:
+            img = add_resize(img, sf=sf)
+        elif i == 9:
+            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+        elif i == 10:
+            if random.random() < poisson_prob:
+                img = add_Poisson_noise(img)
+        elif i == 11:
+            if random.random() < speckle_prob:
+                img = add_speckle_noise(img)
+        elif i == 12:
+            if random.random() < isp_prob and isp_model is not None:
+                with torch.no_grad():
+                    img, hq = isp_model.forward(img.copy(), hq)
+        else:
+            print('check the shuffle!')
+
+    # resize to desired size
+    img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
+                     interpolation=random.choice([1, 2, 3]))
+
+    # add final JPEG compression noise
+    img = add_JPEG_noise(img)
+
+    # random crop
+    img, hq = random_crop(img, hq, sf, lq_patchsize)
+
+    return img, hq
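+
+# Usage sketch (illustrative; the file name is a placeholder). The plus variant follows
+# the same calling convention as degradation_bsrgan but applies the longer 13-step
+# BSRGAN + Real-ESRGAN pipeline:
+#
+#   hq = util.uint2single(util.imread_uint('some_image.png', 3))
+#   lq_patch, hq_patch = degradation_bsrgan_plus(hq, sf=4, lq_patchsize=64)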
+
+
+if __name__ == '__main__':
+    print("hey")
+    img = util.imread_uint('utils/test.png', 3)
+    print(img)
+    img = img[:448, :448]
+    h = img.shape[0] // 4
+    print("resizing to", h)
+    sf = 4
+    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+    for i in range(20):
+        print(i)
+        img_hq = img
+        img_lq = deg_fn(img)["image"]
+        img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
+        print(img_lq)
+        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
+        print(img_lq.shape)
+        print("bicubic", img_lq_bicubic.shape)
+        print(img_hq.shape)
+        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+                                interpolation=0)
+        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
+                                        (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+                                        interpolation=0)
+        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+        util.imsave(img_concat, str(i) + '.png')
+
+
diff --git a/ldm/modules/image_degradation/bsrgan_light.py b/ldm/modules/image_degradation/bsrgan_light.py
new file mode 100644
index 0000000000000000000000000000000000000000..808c7f882cb75e2ba2340d5b55881d11927351f0
--- /dev/null
+++ b/ldm/modules/image_degradation/bsrgan_light.py
@@ -0,0 +1,651 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+import cv2
+import torch
+
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+
+import ldm.modules.image_degradation.utils_image as util
+
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+def modcrop_np(img, sf):
+    '''
+    Args:
+        img: numpy image, WxH or WxHxC
+        sf: scale factor
+    Return:
+        cropped image
+    '''
+    w, h = img.shape[:2]
+    im = np.copy(img)
+    return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+    """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+    k_size = k.shape[0]
+    # Calculate the big kernels size
+    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+    # Loop over the small kernel to fill the big one
+    for r in range(k_size):
+        for c in range(k_size):
+            big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+    # Crop the edges of the big kernel to ignore very small values and increase run time of SR
+    crop = k_size // 2
+    cropped_big_k = big_k[crop:-crop, crop:-crop]
+    # Normalize to 1
+    return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+    """ generate an anisotropic Gaussian kernel
+    Args:
+        ksize : e.g., 15, kernel size
+        theta : [0,  pi], rotation angle range
+        l1    : [0.1,50], scaling of eigenvalues
+        l2    : [0.1,l1], scaling of eigenvalues
+        If l1 = l2, will get an isotropic Gaussian kernel.
+    Returns:
+        k     : kernel
+    """
+
+    v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+    V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+    D = np.array([[l1, 0], [0, l2]])
+    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+    k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+    return k
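+
+# Example (illustrative): an elongated 15x15 kernel rotated by 45 degrees; with l1 == l2
+# the same call would produce an isotropic Gaussian.
+#
+#   k = anisotropic_Gaussian(ksize=15, theta=np.pi / 4, l1=6, l2=1)
+#   assert abs(k.sum() - 1.0) < 1e-6  # gm_blur_kernel normalizes the kernel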
+
+
+def gm_blur_kernel(mean, cov, size=15):
+    center = size / 2.0 + 0.5
+    k = np.zeros([size, size])
+    for y in range(size):
+        for x in range(size):
+            cy = y - center + 1
+            cx = x - center + 1
+            k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+    k = k / np.sum(k)
+    return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+    """shift pixel for super-resolution with different scale factors
+    Args:
+        x: WxHxC or WxH
+        sf: scale factor
+        upper_left: shift direction
+    """
+    h, w = x.shape[:2]
+    shift = (sf - 1) * 0.5
+    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+    if upper_left:
+        x1 = xv + shift
+        y1 = yv + shift
+    else:
+        x1 = xv - shift
+        y1 = yv - shift
+
+    x1 = np.clip(x1, 0, w - 1)
+    y1 = np.clip(y1, 0, h - 1)
+
+    if x.ndim == 2:
+        x = interp2d(xv, yv, x)(x1, y1)
+    if x.ndim == 3:
+        for i in range(x.shape[-1]):
+            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+    return x
+
+
+def blur(x, k):
+    '''
+    x: image, NxcxHxW
+    k: kernel, Nx1xhxw
+    '''
+    n, c = x.shape[:2]
+    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+    x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+    k = k.repeat(1, c, 1, 1)
+    k = k.view(-1, 1, k.shape[2], k.shape[3])
+    x = x.view(1, -1, x.shape[2], x.shape[3])
+    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+    x = x.view(n, c, x.shape[2], x.shape[3])
+
+    return x
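+
+# Shape sketch (illustrative): blur() expects a batch of images and one kernel per image,
+# and applies each kernel to its image via a grouped convolution.
+#
+#   x = torch.rand(2, 3, 64, 64)   # NxCxHxW
+#   k = torch.rand(2, 1, 7, 7)     # Nx1xhxw
+#   y = blur(x, k)                 # same spatial size thanks to replicate padding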
+
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+    """"
+    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+    # Kai Zhang
+    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var
+    # max_var = 2.5 * sf
+    """
+    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+    lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+    lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+    theta = np.random.rand() * np.pi  # random theta
+    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+    # Set COV matrix using Lambdas and Theta
+    LAMBDA = np.diag([lambda_1, lambda_2])
+    Q = np.array([[np.cos(theta), -np.sin(theta)],
+                  [np.sin(theta), np.cos(theta)]])
+    SIGMA = Q @ LAMBDA @ Q.T
+    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+    # Set expectation position (shifting kernel for aligned image)
+    MU = k_size // 2 - 0.5 * (scale_factor - 1)  # - 0.5 * (scale_factor - k_size % 2)
+    MU = MU[None, None, :, None]
+
+    # Create meshgrid for Gaussian
+    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+    Z = np.stack([X, Y], 2)[:, :, :, None]
+
+    # Calculate Gaussian for every pixel of the kernel
+    ZZ = Z - MU
+    ZZ_t = ZZ.transpose(0, 1, 3, 2)
+    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+    # shift the kernel so it will be centered
+    # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+    # Normalize the kernel and return
+    # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+    kernel = raw_kernel / np.sum(raw_kernel)
+    return kernel
+
+
+def fspecial_gaussian(hsize, sigma):
+    hsize = [hsize, hsize]
+    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+    std = sigma
+    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+    arg = -(x * x + y * y) / (2 * std * std)
+    h = np.exp(arg)
+    h[h < np.finfo(float).eps * h.max()] = 0
+    sumh = h.sum()
+    if sumh != 0:
+        h = h / sumh
+    return h
+
+
+def fspecial_laplacian(alpha):
+    alpha = max([0, min([alpha, 1])])
+    h1 = alpha / (alpha + 1)
+    h2 = (1 - alpha) / (alpha + 1)
+    h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+    h = np.array(h)
+    return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+    '''
+    python code from:
+    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+    '''
+    if filter_type == 'gaussian':
+        return fspecial_gaussian(*args, **kwargs)
+    if filter_type == 'laplacian':
+        return fspecial_laplacian(*args, **kwargs)
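+
+# Example (illustrative): MATLAB-style filter construction, as used by the degradation
+# functions below.
+#
+#   g = fspecial('gaussian', 15, 2.0)   # 15x15 Gaussian kernel, sigma=2.0, sums to 1
+#   l = fspecial('laplacian', 0.2)      # 3x3 Laplacian-like kernel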
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+    '''
+    Args:
+        x: HxWxC image, [0, 1]
+        sf: down-scale factor
+    Return:
+        bicubicly downsampled LR image
+    '''
+    x = util.imresize_np(x, scale=1 / sf)
+    return x
+
+
+def srmd_degradation(x, k, sf=3):
+    ''' blur + bicubic downsampling
+    Args:
+        x: HxWxC image, [0, 1]
+        k: hxw, double
+        sf: down-scale factor
+    Return:
+        downsampled LR image
+    Reference:
+        @inproceedings{zhang2018learning,
+          title={Learning a single convolutional super-resolution network for multiple degradations},
+          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+          pages={3262--3271},
+          year={2018}
+        }
+    '''
+    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
+    x = bicubic_degradation(x, sf=sf)
+    return x
+
+
+def dpsr_degradation(x, k, sf=3):
+    ''' bicubic downsampling + blur
+    Args:
+        x: HxWxC image, [0, 1]
+        k: hxw, double
+        sf: down-scale factor
+    Return:
+        downsampled LR image
+    Reference:
+        @inproceedings{zhang2019deep,
+          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+          pages={1671--1681},
+          year={2019}
+        }
+    '''
+    x = bicubic_degradation(x, sf=sf)
+    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+    return x
+
+
+def classical_degradation(x, k, sf=3):
+    ''' blur + downsampling
+    Args:
+        x: HxWxC image, [0, 1]/[0, 255]
+        k: hxw, double
+        sf: down-scale factor
+    Return:
+        downsampled LR image
+    '''
+    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+    st = 0
+    return x[st::sf, st::sf, ...]
+
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+    """USM sharpening. borrowed from real-ESRGAN
+    Input image: I; Blurry image: B.
+    1. K = I + weight * (I - B)
+    2. Mask = 1 if abs(I - B) > threshold, else: 0
+    3. Blur mask:
+    4. Out = Mask * K + (1 - Mask) * I
+    Args:
+        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+        weight (float): Sharp weight. Default: 0.5.
+        radius (float): Kernel size of Gaussian blur. Default: 50.
+        threshold (int): Residual threshold (on a 0-255 scale) for the sharpening mask. Default: 10.
+    """
+    if radius % 2 == 0:
+        radius += 1
+    blur = cv2.GaussianBlur(img, (radius, radius), 0)
+    residual = img - blur
+    mask = np.abs(residual) * 255 > threshold
+    mask = mask.astype('float32')
+    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+    K = img + weight * residual
+    K = np.clip(K, 0, 1)
+    return soft_mask * K + (1 - soft_mask) * img
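+
+# Example (illustrative; the file name is a placeholder): unsharp-mask sharpening of a
+# float HWC image in [0, 1].
+#
+#   img = util.uint2single(util.imread_uint('some_image.png', 3))
+#   sharp = add_sharpening(img, weight=0.5, radius=50, threshold=10)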
+
+
+def add_blur(img, sf=4):
+    wd2 = 4.0 + sf
+    wd = 2.0 + 0.2 * sf
+
+    wd2 = wd2/4
+    wd = wd/4
+
+    if random.random() < 0.5:
+        l1 = wd2 * random.random()
+        l2 = wd2 * random.random()
+        k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+    else:
+        k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
+    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+    return img
+
+
+def add_resize(img, sf=4):
+    rnum = np.random.rand()
+    if rnum > 0.8:  # up
+        sf1 = random.uniform(1, 2)
+    elif rnum < 0.7:  # down
+        sf1 = random.uniform(0.5 / sf, 1)
+    else:
+        sf1 = 1.0
+    img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+    img = np.clip(img, 0.0, 1.0)
+
+    return img
+
+
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+    noise_level = random.randint(noise_level1, noise_level2)
+    rnum = np.random.rand()
+    if rnum > 0.6:  # add color Gaussian noise
+        img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+    elif rnum < 0.4:  # add grayscale Gaussian noise
+        img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+    else:  # add channel-correlated noise
+        L = noise_level2 / 255.
+        D = np.diag(np.random.rand(3))
+        U = orth(np.random.rand(3, 3))
+        conv = np.dot(np.dot(np.transpose(U), D), U)
+        img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+    img = np.clip(img, 0.0, 1.0)
+    return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+    noise_level = random.randint(noise_level1, noise_level2)
+    img = np.clip(img, 0.0, 1.0)
+    rnum = random.random()
+    if rnum > 0.6:
+        img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+    elif rnum < 0.4:
+        img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+    else:
+        L = noise_level2 / 255.
+        D = np.diag(np.random.rand(3))
+        U = orth(np.random.rand(3, 3))
+        conv = np.dot(np.dot(np.transpose(U), D), U)
+        img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+    img = np.clip(img, 0.0, 1.0)
+    return img
+
+
+def add_Poisson_noise(img):
+    img = np.clip((img * 255.0).round(), 0, 255) / 255.
+    vals = 10 ** (2 * random.random() + 2.0)  # [2, 4]
+    if random.random() < 0.5:
+        img = np.random.poisson(img * vals).astype(np.float32) / vals
+    else:
+        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+        noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+        img += noise_gray[:, :, np.newaxis]
+    img = np.clip(img, 0.0, 1.0)
+    return img
+
+
+def add_JPEG_noise(img):
+    quality_factor = random.randint(80, 95)
+    img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+    result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+    img = cv2.imdecode(encimg, 1)
+    img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+    return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+    h, w = lq.shape[:2]
+    rnd_h = random.randint(0, h - lq_patchsize)
+    rnd_w = random.randint(0, w - lq_patchsize)
+    lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+    rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+    hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+    return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+    """
+    This is the degradation model of BSRGAN from the paper
+    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+    ----------
+    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
+    sf: scale factor
+    isp_model: camera ISP model
+    Returns
+    -------
+    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+    """
+    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+    sf_ori = sf
+
+    h1, w1 = img.shape[:2]
+    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop
+    h, w = img.shape[:2]
+
+    if h < lq_patchsize * sf or w < lq_patchsize * sf:
+        raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+    hq = img.copy()
+
+    if sf == 4 and random.random() < scale2_prob:  # downsample1
+        if np.random.rand() < 0.5:
+            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+                             interpolation=random.choice([1, 2, 3]))
+        else:
+            img = util.imresize_np(img, 1 / 2, True)
+        img = np.clip(img, 0.0, 1.0)
+        sf = 2
+
+    shuffle_order = random.sample(range(7), 7)
+    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+    if idx1 > idx2:  # keep downsample3 last
+        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+    for i in shuffle_order:
+
+        if i == 0:
+            img = add_blur(img, sf=sf)
+
+        elif i == 1:
+            img = add_blur(img, sf=sf)
+
+        elif i == 2:
+            a, b = img.shape[1], img.shape[0]
+            # downsample2
+            if random.random() < 0.75:
+                sf1 = random.uniform(1, 2 * sf)
+                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+                                 interpolation=random.choice([1, 2, 3]))
+            else:
+                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+                k_shifted = shift_pixel(k, sf)
+                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
+                img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+                img = img[0::sf, 0::sf, ...]  # nearest downsampling
+            img = np.clip(img, 0.0, 1.0)
+
+        elif i == 3:
+            # downsample3
+            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+            img = np.clip(img, 0.0, 1.0)
+
+        elif i == 4:
+            # add Gaussian noise
+            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
+
+        elif i == 5:
+            # add JPEG noise
+            if random.random() < jpeg_prob:
+                img = add_JPEG_noise(img)
+
+        elif i == 6:
+            # add processed camera sensor noise
+            if random.random() < isp_prob and isp_model is not None:
+                with torch.no_grad():
+                    img, hq = isp_model.forward(img.copy(), hq)
+
+    # add final JPEG compression noise
+    img = add_JPEG_noise(img)
+
+    # random crop
+    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+    return img, hq
+
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False):
+    """
+    This is the degradation model of BSRGAN from the paper
+    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+    ----------
+    sf: scale factor
+    isp_model: camera ISP model
+    up: if True, resize the degraded image back to its original resolution (bicubic)
+    Returns
+    -------
+    example: dict with key "image" holding the degraded uint8 image
+    """
+    image = util.uint2single(image)
+    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+    sf_ori = sf
+
+    h1, w1 = image.shape[:2]
+    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop
+    h, w = image.shape[:2]
+
+    hq = image.copy()
+
+    if sf == 4 and random.random() < scale2_prob:  # downsample1
+        if np.random.rand() < 0.5:
+            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+                               interpolation=random.choice([1, 2, 3]))
+        else:
+            image = util.imresize_np(image, 1 / 2, True)
+        image = np.clip(image, 0.0, 1.0)
+        sf = 2
+
+    shuffle_order = random.sample(range(7), 7)
+    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+    if idx1 > idx2:  # keep downsample3 last
+        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+    for i in shuffle_order:
+
+        if i == 0:
+            image = add_blur(image, sf=sf)
+
+        # elif i == 1:
+        #     image = add_blur(image, sf=sf)
+
+        elif i == 2:
+            a, b = image.shape[1], image.shape[0]
+            # downsample2
+            if random.random() < 0.8:
+                sf1 = random.uniform(1, 2 * sf)
+                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+                                   interpolation=random.choice([1, 2, 3]))
+            else:
+                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+                k_shifted = shift_pixel(k, sf)
+                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
+                image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+                image = image[0::sf, 0::sf, ...]  # nearest downsampling
+
+            image = np.clip(image, 0.0, 1.0)
+
+        elif i == 3:
+            # downsample3
+            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+            image = np.clip(image, 0.0, 1.0)
+
+        elif i == 4:
+            # add Gaussian noise
+            image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
+
+        elif i == 5:
+            # add JPEG noise
+            if random.random() < jpeg_prob:
+                image = add_JPEG_noise(image)
+        #
+        # elif i == 6:
+        #     # add processed camera sensor noise
+        #     if random.random() < isp_prob and isp_model is not None:
+        #         with torch.no_grad():
+        #             img, hq = isp_model.forward(img.copy(), hq)
+
+    # add final JPEG compression noise
+    image = add_JPEG_noise(image)
+    image = util.single2uint(image)
+    if up:
+        image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC)  # todo: random, as above? want to condition on it then
+    example = {"image": image}
+    return example
+
+
+
+
+if __name__ == '__main__':
+    print("hey")
+    img = util.imread_uint('utils/test.png', 3)
+    img = img[:448, :448]
+    h = img.shape[0] // 4
+    print("resizing to", h)
+    sf = 4
+    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+    for i in range(20):
+        print(i)
+        img_hq = img
+        img_lq = deg_fn(img)["image"]
+        img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
+        print(img_lq)
+        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
+        print(img_lq.shape)
+        print("bicubic", img_lq_bicubic.shape)
+        print(img_hq.shape)
+        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+                                interpolation=0)
+        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
+                                        (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+                                        interpolation=0)
+        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+        util.imsave(img_concat, str(i) + '.png')
diff --git a/ldm/modules/image_degradation/utils/test.png b/ldm/modules/image_degradation/utils/test.png
new file mode 100644
index 0000000000000000000000000000000000000000..4249b43de0f22707758d13c240268a401642f6e6
Binary files /dev/null and b/ldm/modules/image_degradation/utils/test.png differ
diff --git a/ldm/modules/image_degradation/utils_image.py b/ldm/modules/image_degradation/utils_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..0175f155ad900ae33c3c46ed87f49b352e3faf98
--- /dev/null
+++ b/ldm/modules/image_degradation/utils_image.py
@@ -0,0 +1,916 @@
+import os
+import math
+import random
+import numpy as np
+import torch
+import cv2
+from torchvision.utils import make_grid
+from datetime import datetime
+#import matplotlib.pyplot as plt   # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
+
+
+os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
+
+
+'''
+# --------------------------------------------
+# Kai Zhang (github: https://github.com/cszn)
+# 03/Mar/2019
+# --------------------------------------------
+# https://github.com/twhui/SRGAN-pyTorch
+# https://github.com/xinntao/BasicSR
+# --------------------------------------------
+'''
+
+
+IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
+
+
+def is_image_file(filename):
+    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def get_timestamp():
+    return datetime.now().strftime('%y%m%d-%H%M%S')
+
+
+def imshow(x, title=None, cbar=False, figsize=None):
+    import matplotlib.pyplot as plt  # local import: matplotlib is optional (see TODO above)
+    plt.figure(figsize=figsize)
+    plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
+    if title:
+        plt.title(title)
+    if cbar:
+        plt.colorbar()
+    plt.show()
+
+
+def surf(Z, cmap='rainbow', figsize=None):
+    import matplotlib.pyplot as plt  # local import: matplotlib is optional (see TODO above)
+    plt.figure(figsize=figsize)
+    ax3 = plt.axes(projection='3d')
+
+    w, h = Z.shape[:2]
+    xx = np.arange(0,w,1)
+    yy = np.arange(0,h,1)
+    X, Y = np.meshgrid(xx, yy)
+    ax3.plot_surface(X,Y,Z,cmap=cmap)
+    #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
+    plt.show()
+
+
+'''
+# --------------------------------------------
+# get image paths
+# --------------------------------------------
+'''
+
+
+def get_image_paths(dataroot):
+    paths = None  # return None if dataroot is None
+    if dataroot is not None:
+        paths = sorted(_get_paths_from_images(dataroot))
+    return paths
+
+
+def _get_paths_from_images(path):
+    assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
+    images = []
+    for dirpath, _, fnames in sorted(os.walk(path)):
+        for fname in sorted(fnames):
+            if is_image_file(fname):
+                img_path = os.path.join(dirpath, fname)
+                images.append(img_path)
+    assert images, '{:s} has no valid image file'.format(path)
+    return images
+
+
+'''
+# --------------------------------------------
+# split large images into small images 
+# --------------------------------------------
+'''
+
+
+def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
+    w, h = img.shape[:2]
+    patches = []
+    if w > p_max and h > p_max:
+        w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=int))
+        h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=int))
+        w1.append(w-p_size)
+        h1.append(h-p_size)
+#        print(w1)
+#        print(h1)
+        for i in w1:
+            for j in h1:
+                patches.append(img[i:i+p_size, j:j+p_size,:])
+    else:
+        patches.append(img)
+
+    return patches
+
+
+def imssave(imgs, img_path):
+    """
+    imgs: list, N images of size WxHxC
+    """
+    img_name, ext = os.path.splitext(os.path.basename(img_path))
+
+    for i, img in enumerate(imgs):
+        if img.ndim == 3:
+            img = img[:, :, [2, 1, 0]]
+        new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')
+        cv2.imwrite(new_path, img)
+
+
+def split_imageset(original_dataroot, target_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
+    """
+    Split the large images from original_dataroot into small overlapped images of size (p_size)x(p_size)
+    and save them into target_dataroot; only images larger than (p_max)x(p_max) are split.
+    Args:
+        original_dataroot: directory containing the source images
+        target_dataroot: directory where the patches are saved
+        p_size: size of the small images
+        p_overlap: overlap between patches; the patch size used in training is a good choice
+        p_max: images smaller than (p_max)x(p_max) are kept unchanged
+    """
+    paths = get_image_paths(original_dataroot)
+    for img_path in paths:
+        # img_name, ext = os.path.splitext(os.path.basename(img_path))
+        img = imread_uint(img_path, n_channels=n_channels)
+        patches = patches_from_image(img, p_size, p_overlap, p_max)
+        imssave(patches, os.path.join(target_dataroot, os.path.basename(img_path)))
+        # if original_dataroot == target_dataroot:
+        #     del img_path
+
+'''
+# --------------------------------------------
+# makedir
+# --------------------------------------------
+'''
+
+
+def mkdir(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+
+def mkdirs(paths):
+    if isinstance(paths, str):
+        mkdir(paths)
+    else:
+        for path in paths:
+            mkdir(path)
+
+
+def mkdir_and_rename(path):
+    if os.path.exists(path):
+        new_name = path + '_archived_' + get_timestamp()
+        print('Path already exists. Renaming it to [{:s}]'.format(new_name))
+        os.rename(path, new_name)
+    os.makedirs(path)
+
+
+'''
+# --------------------------------------------
+# read image from path
+# opencv is fast, but read BGR numpy image
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# get uint8 image of size HxWxn_channels (RGB)
+# --------------------------------------------
+def imread_uint(path, n_channels=3):
+    #  input: path
+    # output: HxWx3(RGB or GGG), or HxWx1 (G)
+    if n_channels == 1:
+        img = cv2.imread(path, 0)  # cv2.IMREAD_GRAYSCALE
+        img = np.expand_dims(img, axis=2)  # HxWx1
+    elif n_channels == 3:
+        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR or G
+        if img.ndim == 2:
+            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG
+        else:
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB
+    return img
+
+
+# --------------------------------------------
+# matlab's imwrite
+# --------------------------------------------
+def imsave(img, img_path):
+    img = np.squeeze(img)
+    if img.ndim == 3:
+        img = img[:, :, [2, 1, 0]]
+    cv2.imwrite(img_path, img)
+
+def imwrite(img, img_path):
+    img = np.squeeze(img)
+    if img.ndim == 3:
+        img = img[:, :, [2, 1, 0]]
+    cv2.imwrite(img_path, img)
+
+
+
+# --------------------------------------------
+# get single image of size HxWxn_channels (BGR)
+# --------------------------------------------
+def read_img(path):
+    # read image by cv2
+    # return: Numpy float32, HWC, BGR, [0,1]
+    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # cv2.IMREAD_GRAYSCALE
+    img = img.astype(np.float32) / 255.
+    if img.ndim == 2:
+        img = np.expand_dims(img, axis=2)
+    # some images have 4 channels
+    if img.shape[2] > 3:
+        img = img[:, :, :3]
+    return img
+
+
+'''
+# --------------------------------------------
+# image format conversion
+# --------------------------------------------
+# numpy(single) <--->  numpy(uint)
+# numpy(single) <--->  tensor
+# numpy(uint)   <--->  tensor
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# numpy(single) [0, 1] <--->  numpy(uint)
+# --------------------------------------------
+
+
+def uint2single(img):
+
+    return np.float32(img/255.)
+
+
+def single2uint(img):
+
+    return np.uint8((img.clip(0, 1)*255.).round())
+
+
+def uint162single(img):
+
+    return np.float32(img/65535.)
+
+
+def single2uint16(img):
+
+    return np.uint16((img.clip(0, 1)*65535.).round())
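+
+# Round-trip sketch (illustrative; the file name is a placeholder): these helpers convert
+# between uint images in [0, 255] / [0, 65535] and float32 images in [0, 1].
+#
+#   img_u8 = imread_uint('some_image.png', 3)       # uint8, HxWx3, RGB
+#   img_f  = uint2single(img_u8)                    # float32 in [0, 1]
+#   assert (single2uint(img_f) == img_u8).all()     # lossless round trip for uint8 input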
+
+
+# --------------------------------------------
+# numpy(uint) (HxWxC or HxW) <--->  tensor
+# --------------------------------------------
+
+
+# convert uint to 4-dimensional torch tensor
+def uint2tensor4(img):
+    if img.ndim == 2:
+        img = np.expand_dims(img, axis=2)
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
+
+
+# convert uint to 3-dimensional torch tensor
+def uint2tensor3(img):
+    if img.ndim == 2:
+        img = np.expand_dims(img, axis=2)
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
+
+
+# convert 2/3/4-dimensional torch tensor to uint
+def tensor2uint(img):
+    img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
+    if img.ndim == 3:
+        img = np.transpose(img, (1, 2, 0))
+    return np.uint8((img*255.0).round())
+
+
+# --------------------------------------------
+# numpy(single) (HxWxC) <--->  tensor
+# --------------------------------------------
+
+
+# convert single (HxWxC) to 3-dimensional torch tensor
+def single2tensor3(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
+
+
+# convert single (HxWxC) to 4-dimensional torch tensor
+def single2tensor4(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
+
+
+# convert torch tensor to single
+def tensor2single(img):
+    img = img.data.squeeze().float().cpu().numpy()
+    if img.ndim == 3:
+        img = np.transpose(img, (1, 2, 0))
+
+    return img
+
+# convert torch tensor to single
+def tensor2single3(img):
+    img = img.data.squeeze().float().cpu().numpy()
+    if img.ndim == 3:
+        img = np.transpose(img, (1, 2, 0))
+    elif img.ndim == 2:
+        img = np.expand_dims(img, axis=2)
+    return img
+
+
+def single2tensor5(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
+
+
+def single32tensor5(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
+
+
+def single42tensor4(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
+
+
+# from skimage.io import imread, imsave
+def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
+    '''
+    Converts a torch Tensor into an image Numpy array of BGR channel order
+    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
+    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
+    '''
+    tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # squeeze first, then clamp
+    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0,1]
+    n_dim = tensor.dim()
+    if n_dim == 4:
+        n_img = len(tensor)
+        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
+        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
+    elif n_dim == 3:
+        img_np = tensor.numpy()
+        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
+    elif n_dim == 2:
+        img_np = tensor.numpy()
+    else:
+        raise TypeError(
+            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
+    if out_type == np.uint8:
+        img_np = (img_np * 255.0).round()
+        # Important. Unlike MATLAB, numpy.uint8() WILL NOT round by default.
+    return img_np.astype(out_type)
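+
+# Example (illustrative; 'preview.png' is a placeholder path): converting a CHW RGB tensor
+# in [0, 1] into a uint8 HWC BGR array ready for cv2.imwrite.
+#
+#   t = torch.rand(3, 64, 64)
+#   bgr = tensor2img(t)              # (64, 64, 3), uint8, BGR channel order
+#   cv2.imwrite('preview.png', bgr)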
+
+
+'''
+# --------------------------------------------
+# Augmentation, flip and/or rotate
+# --------------------------------------------
+# The following two are enough.
+# (1) augment_img: numpy image of WxHxC or WxH
+# (2) augment_img_tensor4: tensor image 1xCxWxH
+# --------------------------------------------
+'''
+
+
+def augment_img(img, mode=0):
+    '''Kai Zhang (github: https://github.com/cszn)
+    '''
+    if mode == 0:
+        return img
+    elif mode == 1:
+        return np.flipud(np.rot90(img))
+    elif mode == 2:
+        return np.flipud(img)
+    elif mode == 3:
+        return np.rot90(img, k=3)
+    elif mode == 4:
+        return np.flipud(np.rot90(img, k=2))
+    elif mode == 5:
+        return np.rot90(img)
+    elif mode == 6:
+        return np.rot90(img, k=2)
+    elif mode == 7:
+        return np.flipud(np.rot90(img, k=3))
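+
+# Example (illustrative): the eight modes cover the flip/rotation augmentations; mode 0
+# is the identity.
+#
+#   aug = augment_img(img, mode=random.randint(0, 7))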
+
+
+def augment_img_tensor4(img, mode=0):
+    '''Kai Zhang (github: https://github.com/cszn)
+    '''
+    if mode == 0:
+        return img
+    elif mode == 1:
+        return img.rot90(1, [2, 3]).flip([2])
+    elif mode == 2:
+        return img.flip([2])
+    elif mode == 3:
+        return img.rot90(3, [2, 3])
+    elif mode == 4:
+        return img.rot90(2, [2, 3]).flip([2])
+    elif mode == 5:
+        return img.rot90(1, [2, 3])
+    elif mode == 6:
+        return img.rot90(2, [2, 3])
+    elif mode == 7:
+        return img.rot90(3, [2, 3]).flip([2])
+
+
+def augment_img_tensor(img, mode=0):
+    '''Kai Zhang (github: https://github.com/cszn)
+    '''
+    img_size = img.size()
+    img_np = img.data.cpu().numpy()
+    if len(img_size) == 3:
+        img_np = np.transpose(img_np, (1, 2, 0))
+    elif len(img_size) == 4:
+        img_np = np.transpose(img_np, (2, 3, 1, 0))
+    img_np = augment_img(img_np, mode=mode)
+    img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
+    if len(img_size) == 3:
+        img_tensor = img_tensor.permute(2, 0, 1)
+    elif len(img_size) == 4:
+        img_tensor = img_tensor.permute(3, 2, 0, 1)
+
+    return img_tensor.type_as(img)
+
+
+def augment_img_np3(img, mode=0):
+    if mode == 0:
+        return img
+    elif mode == 1:
+        return img.transpose(1, 0, 2)
+    elif mode == 2:
+        return img[::-1, :, :]
+    elif mode == 3:
+        img = img[::-1, :, :]
+        img = img.transpose(1, 0, 2)
+        return img
+    elif mode == 4:
+        return img[:, ::-1, :]
+    elif mode == 5:
+        img = img[:, ::-1, :]
+        img = img.transpose(1, 0, 2)
+        return img
+    elif mode == 6:
+        img = img[:, ::-1, :]
+        img = img[::-1, :, :]
+        return img
+    elif mode == 7:
+        img = img[:, ::-1, :]
+        img = img[::-1, :, :]
+        img = img.transpose(1, 0, 2)
+        return img
+
+
+def augment_imgs(img_list, hflip=True, rot=True):
+    # random horizontal flip, vertical flip and/or 90-degree rotation
+    hflip = hflip and random.random() < 0.5
+    vflip = rot and random.random() < 0.5
+    rot90 = rot and random.random() < 0.5
+
+    def _augment(img):
+        if hflip:
+            img = img[:, ::-1, :]
+        if vflip:
+            img = img[::-1, :, :]
+        if rot90:
+            img = img.transpose(1, 0, 2)
+        return img
+
+    return [_augment(img) for img in img_list]
+
+
+'''
+# --------------------------------------------
+# modcrop and shave
+# --------------------------------------------
+'''
+
+
+def modcrop(img_in, scale):
+    # img_in: Numpy, HWC or HW
+    img = np.copy(img_in)
+    if img.ndim == 2:
+        H, W = img.shape
+        H_r, W_r = H % scale, W % scale
+        img = img[:H - H_r, :W - W_r]
+    elif img.ndim == 3:
+        H, W, C = img.shape
+        H_r, W_r = H % scale, W % scale
+        img = img[:H - H_r, :W - W_r, :]
+    else:
+        raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
+    return img
+
+
+def shave(img_in, border=0):
+    # img_in: Numpy, HWC or HW
+    img = np.copy(img_in)
+    h, w = img.shape[:2]
+    img = img[border:h-border, border:w-border]
+    return img
+
+
+'''
+# --------------------------------------------
+# image processing process on numpy image
+# channel_convert(in_c, tar_type, img_list):
+# rgb2ycbcr(img, only_y=True):
+# bgr2ycbcr(img, only_y=True):
+# ycbcr2rgb(img):
+# --------------------------------------------
+'''
+
+
+def rgb2ycbcr(img, only_y=True):
+    '''same as matlab rgb2ycbcr
+    only_y: only return Y channel
+    Input:
+        uint8, [0, 255]
+        float, [0, 1]
+    '''
+    in_img_type = img.dtype
+    img = img.astype(np.float32)
+    if in_img_type != np.uint8:
+        img *= 255.
+    # convert
+    if only_y:
+        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
+    else:
+        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
+    if in_img_type == np.uint8:
+        rlt = rlt.round()
+    else:
+        rlt /= 255.
+    return rlt.astype(in_img_type)
+
+
+def ycbcr2rgb(img):
+    '''same as matlab ycbcr2rgb
+    Input:
+        uint8, [0, 255]
+        float, [0, 1]
+    '''
+    in_img_type = img.dtype
+    img = img.astype(np.float32)
+    if in_img_type != np.uint8:
+        img *= 255.
+    # convert
+    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+                          [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
+    if in_img_type == np.uint8:
+        rlt = rlt.round()
+    else:
+        rlt /= 255.
+    return rlt.astype(in_img_type)
+
+
+def bgr2ycbcr(img, only_y=True):
+    '''bgr version of rgb2ycbcr
+    only_y: only return Y channel
+    Input:
+        uint8, [0, 255]
+        float, [0, 1]
+    '''
+    in_img_type = img.dtype
+    img = img.astype(np.float32)
+    if in_img_type != np.uint8:
+        img *= 255.
+    # convert
+    if only_y:
+        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
+    else:
+        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
+    if in_img_type == np.uint8:
+        rlt = rlt.round()
+    else:
+        rlt /= 255.
+    return rlt.astype(in_img_type)
+
+
+def channel_convert(in_c, tar_type, img_list):
+    # conversion among BGR, gray and y
+    if in_c == 3 and tar_type == 'gray':  # BGR to gray
+        gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
+        return [np.expand_dims(img, axis=2) for img in gray_list]
+    elif in_c == 3 and tar_type == 'y':  # BGR to y
+        y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
+        return [np.expand_dims(img, axis=2) for img in y_list]
+    elif in_c == 1 and tar_type == 'RGB':  # gray/y to BGR
+        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
+    else:
+        return img_list
+
+
+'''
+# --------------------------------------------
+# metric, PSNR and SSIM
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# PSNR
+# --------------------------------------------
+def calculate_psnr(img1, img2, border=0):
+    # img1 and img2 have range [0, 255]
+    #img1 = img1.squeeze()
+    #img2 = img2.squeeze()
+    if not img1.shape == img2.shape:
+        raise ValueError('Input images must have the same dimensions.')
+    h, w = img1.shape[:2]
+    img1 = img1[border:h-border, border:w-border]
+    img2 = img2[border:h-border, border:w-border]
+
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+    mse = np.mean((img1 - img2)**2)
+    if mse == 0:
+        return float('inf')
+    return 20 * math.log10(255.0 / math.sqrt(mse))
+
+
+# --------------------------------------------
+# SSIM
+# --------------------------------------------
+def calculate_ssim(img1, img2, border=0):
+    '''calculate SSIM
+    the same outputs as MATLAB's
+    img1, img2: [0, 255]
+    '''
+    #img1 = img1.squeeze()
+    #img2 = img2.squeeze()
+    if not img1.shape == img2.shape:
+        raise ValueError('Input images must have the same dimensions.')
+    h, w = img1.shape[:2]
+    img1 = img1[border:h-border, border:w-border]
+    img2 = img2[border:h-border, border:w-border]
+
+    if img1.ndim == 2:
+        return ssim(img1, img2)
+    elif img1.ndim == 3:
+        if img1.shape[2] == 3:
+            ssims = []
+            for i in range(3):
+                ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
+            return np.array(ssims).mean()
+        elif img1.shape[2] == 1:
+            return ssim(np.squeeze(img1), np.squeeze(img2))
+    else:
+        raise ValueError('Wrong input image dimensions.')
+
+
+def ssim(img1, img2):
+    C1 = (0.01 * 255)**2
+    C2 = (0.03 * 255)**2
+
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+    kernel = cv2.getGaussianKernel(11, 1.5)
+    window = np.outer(kernel, kernel.transpose())
+
+    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
+    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+    mu1_sq = mu1**2
+    mu2_sq = mu2**2
+    mu1_mu2 = mu1 * mu2
+    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+                                                            (sigma1_sq + sigma2_sq + C2))
+    return ssim_map.mean()
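+
+# Usage sketch (illustrative; 'ref.png' and 'out.png' are placeholder paths): both metrics
+# expect images in [0, 255] with identical shapes; calculate_ssim averages the per-channel
+# SSIM for color images.
+#
+#   ref = imread_uint('ref.png', 3)
+#   out = imread_uint('out.png', 3)
+#   psnr_val = calculate_psnr(ref, out, border=4)
+#   ssim_val = calculate_ssim(ref, out, border=4)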
+
+
+'''
+# --------------------------------------------
+# matlab's bicubic imresize (numpy and torch) [0, 1]
+# --------------------------------------------
+'''
+
+
+# matlab 'imresize' function, now only support 'bicubic'
+def cubic(x):
+    absx = torch.abs(x)
+    absx2 = absx**2
+    absx3 = absx**3
+    return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
+        (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
+
+
+def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+    if (scale < 1) and (antialiasing):
+        # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
+        kernel_width = kernel_width / scale
+
+    # Output-space coordinates
+    x = torch.linspace(1, out_length, out_length)
+
+    # Input-space coordinates. Calculate the inverse mapping such that 0.5
+    # in output space maps to 0.5 in input space, and 0.5+scale in output
+    # space maps to 1.5 in input space.
+    u = x / scale + 0.5 * (1 - 1 / scale)
+
+    # What is the left-most pixel that can be involved in the computation?
+    left = torch.floor(u - kernel_width / 2)
+
+    # What is the maximum number of pixels that can be involved in the
+    # computation?  Note: it's OK to use an extra pixel here; if the
+    # corresponding weights are all zero, it will be eliminated at the end
+    # of this function.
+    P = math.ceil(kernel_width) + 2
+
+    # The indices of the input pixels involved in computing the k-th output
+    # pixel are in row k of the indices matrix.
+    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
+        1, P).expand(out_length, P)
+
+    # The weights used to compute the k-th output pixel are in row k of the
+    # weights matrix.
+    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
+    # apply cubic kernel
+    if (scale < 1) and (antialiasing):
+        weights = scale * cubic(distance_to_center * scale)
+    else:
+        weights = cubic(distance_to_center)
+    # Normalize the weights matrix so that each row sums to 1.
+    weights_sum = torch.sum(weights, 1).view(out_length, 1)
+    weights = weights / weights_sum.expand(out_length, P)
+
+    # If a column in weights is all zero, get rid of it. only consider the first and last column.
+    weights_zero_tmp = torch.sum((weights == 0), 0)
+    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+        indices = indices.narrow(1, 1, P - 2)
+        weights = weights.narrow(1, 1, P - 2)
+    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+        indices = indices.narrow(1, 0, P - 2)
+        weights = weights.narrow(1, 0, P - 2)
+    weights = weights.contiguous()
+    indices = indices.contiguous()
+    sym_len_s = -indices.min() + 1
+    sym_len_e = indices.max() - in_length
+    indices = indices + sym_len_s - 1
+    return weights, indices, int(sym_len_s), int(sym_len_e)
+
+
+# --------------------------------------------
+# imresize for tensor image [0, 1]
+# --------------------------------------------
+def imresize(img, scale, antialiasing=True):
+    # Now the scale should be the same for H and W
+    # input: img: pytorch tensor, CHW or HW [0,1]
+    # output: CHW or HW [0,1] w/o round
+    need_squeeze = True if img.dim() == 2 else False
+    if need_squeeze:
+        img.unsqueeze_(0)
+    in_C, in_H, in_W = img.size()
+    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+    kernel_width = 4
+    kernel = 'cubic'
+
+    # MATLAB's imresize chooses the dimension order by resizing along the dimension
+    # with the smallest scale factor first; that optimization is not implemented
+    # here, so H is always processed before W.
+
+    # get weights and indices
+    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+        in_H, out_H, scale, kernel, kernel_width, antialiasing)
+    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+        in_W, out_W, scale, kernel, kernel_width, antialiasing)
+    # process H dimension
+    # symmetric copying
+    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
+    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
+
+    sym_patch = img[:, :sym_len_Hs, :]
+    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(1, inv_idx)
+    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+    sym_patch = img[:, -sym_len_He:, :]
+    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(1, inv_idx)
+    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+    out_1 = torch.FloatTensor(in_C, out_H, in_W)
+    kernel_width = weights_H.size(1)
+    for i in range(out_H):
+        idx = int(indices_H[i][0])
+        for j in range(out_C):
+            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
+
+    # process W dimension
+    # symmetric copying
+    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
+    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
+
+    sym_patch = out_1[:, :, :sym_len_Ws]
+    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(2, inv_idx)
+    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+    sym_patch = out_1[:, :, -sym_len_We:]
+    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(2, inv_idx)
+    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+    out_2 = torch.FloatTensor(in_C, out_H, out_W)
+    kernel_width = weights_W.size(1)
+    for i in range(out_W):
+        idx = int(indices_W[i][0])
+        for j in range(out_C):
+            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
+    if need_squeeze:
+        out_2.squeeze_()
+    return out_2
+
+
+# --------------------------------------------
+# imresize for numpy image [0, 1]
+# --------------------------------------------
+def imresize_np(img, scale, antialiasing=True):
+    # Now the scale should be the same for H and W
+    # input: img: Numpy, HWC or HW [0,1]
+    # output: HWC or HW [0,1] w/o round
+    img = torch.from_numpy(img)
+    need_squeeze = True if img.dim() == 2 else False
+    if need_squeeze:
+        img.unsqueeze_(2)
+
+    in_H, in_W, in_C = img.size()
+    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+    kernel_width = 4
+    kernel = 'cubic'
+
+    # MATLAB's imresize chooses the dimension order by resizing along the dimension
+    # with the smallest scale factor first; that optimization is not implemented
+    # here, so H is always processed before W.
+
+    # get weights and indices
+    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+        in_H, out_H, scale, kernel, kernel_width, antialiasing)
+    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+        in_W, out_W, scale, kernel, kernel_width, antialiasing)
+    # process H dimension
+    # symmetric copying
+    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
+    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
+
+    sym_patch = img[:sym_len_Hs, :, :]
+    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(0, inv_idx)
+    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+    sym_patch = img[-sym_len_He:, :, :]
+    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(0, inv_idx)
+    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+    out_1 = torch.FloatTensor(out_H, in_W, in_C)
+    kernel_width = weights_H.size(1)
+    for i in range(out_H):
+        idx = int(indices_H[i][0])
+        for j in range(out_C):
+            out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
+
+    # process W dimension
+    # symmetric copying
+    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
+    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
+
+    sym_patch = out_1[:, :sym_len_Ws, :]
+    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(1, inv_idx)
+    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+    sym_patch = out_1[:, -sym_len_We:, :]
+    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+    sym_patch_inv = sym_patch.index_select(1, inv_idx)
+    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+    out_2 = torch.FloatTensor(out_H, out_W, in_C)
+    kernel_width = weights_W.size(1)
+    for i in range(out_W):
+        idx = int(indices_W[i][0])
+        for j in range(out_C):
+            out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
+    if need_squeeze:
+        out_2.squeeze_()
+
+    return out_2.numpy()
+
+
+if __name__ == '__main__':
+    print('---')
+#    img = imread_uint('test.bmp', 3)
+#    img = uint2single(img)
+#    img_bicubic = imresize_np(img, 1/4)
\ No newline at end of file
diff --git a/ldm/modules/midas/__init__.py b/ldm/modules/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/midas/api.py b/ldm/modules/midas/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..b58ebbffd942a2fc22264f0ab47e400c26b9f41c
--- /dev/null
+++ b/ldm/modules/midas/api.py
@@ -0,0 +1,170 @@
+# based on https://github.com/isl-org/MiDaS
+
+import cv2
+import torch
+import torch.nn as nn
+from torchvision.transforms import Compose
+
+from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
+from ldm.modules.midas.midas.midas_net import MidasNet
+from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
+from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
+
+
+ISL_PATHS = {
+    "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
+    "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
+    "midas_v21": "",
+    "midas_v21_small": "",
+}
+
+
+def disabled_train(self, mode=True):
+    """Overwrite model.train with this function to make sure train/eval mode
+    does not change anymore."""
+    return self
+
+
+def load_midas_transform(model_type):
+    # https://github.com/isl-org/MiDaS/blob/master/run.py
+    # load transform only
+    if model_type == "dpt_large":  # DPT-Large
+        net_w, net_h = 384, 384
+        resize_mode = "minimal"
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+    elif model_type == "dpt_hybrid":  # DPT-Hybrid
+        net_w, net_h = 384, 384
+        resize_mode = "minimal"
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+    elif model_type == "midas_v21":
+        net_w, net_h = 384, 384
+        resize_mode = "upper_bound"
+        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+    elif model_type == "midas_v21_small":
+        net_w, net_h = 256, 256
+        resize_mode = "upper_bound"
+        normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+    else:
+        assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
+
+    transform = Compose(
+        [
+            Resize(
+                net_w,
+                net_h,
+                resize_target=None,
+                keep_aspect_ratio=True,
+                ensure_multiple_of=32,
+                resize_method=resize_mode,
+                image_interpolation_method=cv2.INTER_CUBIC,
+            ),
+            normalization,
+            PrepareForNet(),
+        ]
+    )
+
+    return transform
+
+
+def load_model(model_type):
+    # https://github.com/isl-org/MiDaS/blob/master/run.py
+    # load network
+    model_path = ISL_PATHS[model_type]
+    if model_type == "dpt_large":  # DPT-Large
+        model = DPTDepthModel(
+            path=model_path,
+            backbone="vitl16_384",
+            non_negative=True,
+        )
+        net_w, net_h = 384, 384
+        resize_mode = "minimal"
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+    elif model_type == "dpt_hybrid":  # DPT-Hybrid
+        model = DPTDepthModel(
+            path=model_path,
+            backbone="vitb_rn50_384",
+            non_negative=True,
+        )
+        net_w, net_h = 384, 384
+        resize_mode = "minimal"
+        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+    elif model_type == "midas_v21":
+        model = MidasNet(model_path, non_negative=True)
+        net_w, net_h = 384, 384
+        resize_mode = "upper_bound"
+        normalization = NormalizeImage(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+        )
+
+    elif model_type == "midas_v21_small":
+        model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
+                               non_negative=True, blocks={'expand': True})
+        net_w, net_h = 256, 256
+        resize_mode = "upper_bound"
+        normalization = NormalizeImage(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+        )
+
+    else:
+        print(f"model_type '{model_type}' not implemented, use: --model_type large")
+        assert False
+
+    transform = Compose(
+        [
+            Resize(
+                net_w,
+                net_h,
+                resize_target=None,
+                keep_aspect_ratio=True,
+                ensure_multiple_of=32,
+                resize_method=resize_mode,
+                image_interpolation_method=cv2.INTER_CUBIC,
+            ),
+            normalization,
+            PrepareForNet(),
+        ]
+    )
+
+    return model.eval(), transform
+
+
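+# Thin inference wrapper: loads a MiDaS model, disables train() on it, and returns a
+# (B, 1, H, W) depth prediction resized back to the input resolution.
+# Hypothetical usage (assumes the matching checkpoint exists at its ISL_PATHS location):
+#     model = MiDaSInference("dpt_hybrid")
+#     depth = model(torch.randn(1, 3, 384, 384))  # -> shape (1, 1, 384, 384)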
+class MiDaSInference(nn.Module):
+    MODEL_TYPES_TORCH_HUB = [
+        "DPT_Large",
+        "DPT_Hybrid",
+        "MiDaS_small"
+    ]
+    MODEL_TYPES_ISL = [
+        "dpt_large",
+        "dpt_hybrid",
+        "midas_v21",
+        "midas_v21_small",
+    ]
+
+    def __init__(self, model_type):
+        super().__init__()
+        assert (model_type in self.MODEL_TYPES_ISL)
+        model, _ = load_model(model_type)
+        self.model = model
+        self.model.train = disabled_train
+
+    def forward(self, x):
+        # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array
+        # NOTE: we expect that the correct transform has been called during dataloading.
+        with torch.no_grad():
+            prediction = self.model(x)
+            prediction = torch.nn.functional.interpolate(
+                prediction.unsqueeze(1),
+                size=x.shape[2:],
+                mode="bicubic",
+                align_corners=False,
+            )
+        assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
+        return prediction
+
diff --git a/ldm/modules/midas/midas/__init__.py b/ldm/modules/midas/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/midas/midas/base_model.py b/ldm/modules/midas/midas/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd
--- /dev/null
+++ b/ldm/modules/midas/midas/base_model.py
@@ -0,0 +1,16 @@
+import torch
+
+
+class BaseModel(torch.nn.Module):
+    def load(self, path):
+        """Load model from file.
+
+        Args:
+            path (str): file path
+        """
+        parameters = torch.load(path, map_location=torch.device('cpu'))
+
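+        # Full training checkpoints also carry optimizer state; in that case the
+        # network weights live under the "model" key.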
+        if "optimizer" in parameters:
+            parameters = parameters["model"]
+
+        self.load_state_dict(parameters)
diff --git a/ldm/modules/midas/midas/blocks.py b/ldm/modules/midas/midas/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78
--- /dev/null
+++ b/ldm/modules/midas/midas/blocks.py
@@ -0,0 +1,342 @@
+import torch
+import torch.nn as nn
+
+from .vit import (
+    _make_pretrained_vitb_rn50_384,
+    _make_pretrained_vitl16_384,
+    _make_pretrained_vitb16_384,
+    forward_vit,
+)
+
+def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
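+    # Returns a (pretrained, scratch) pair: a wrapper around the (possibly hooked)
+    # backbone and the decoder-side 3x3 projection convs sized for its feature channels.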
+    if backbone == "vitl16_384":
+        pretrained = _make_pretrained_vitl16_384(
+            use_pretrained, hooks=hooks, use_readout=use_readout
+        )
+        scratch = _make_scratch(
+            [256, 512, 1024, 1024], features, groups=groups, expand=expand
+        )  # ViT-L/16 - 85.0% Top1 (backbone)
+    elif backbone == "vitb_rn50_384":
+        pretrained = _make_pretrained_vitb_rn50_384(
+            use_pretrained,
+            hooks=hooks,
+            use_vit_only=use_vit_only,
+            use_readout=use_readout,
+        )
+        scratch = _make_scratch(
+            [256, 512, 768, 768], features, groups=groups, expand=expand
+        )  # ViT-Hybrid (ResNet-50 + ViT-B/16) backbone
+    elif backbone == "vitb16_384":
+        pretrained = _make_pretrained_vitb16_384(
+            use_pretrained, hooks=hooks, use_readout=use_readout
+        )
+        scratch = _make_scratch(
+            [96, 192, 384, 768], features, groups=groups, expand=expand
+        )  # ViT-B/16 - 84.6% Top1 (backbone)
+    elif backbone == "resnext101_wsl":
+        pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
+        scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand)  # resnext101_wsl
+    elif backbone == "efficientnet_lite3":
+        pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
+        scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand)  # efficientnet_lite3     
+    else:
+        print(f"Backbone '{backbone}' not implemented")
+        assert False
+        
+    return pretrained, scratch
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
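+    # Four 3x3 convs that project the encoder feature maps (in_shape[i] channels)
+    # to the decoder width (out_shape, optionally expanded per level).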
+    scratch = nn.Module()
+
+    out_shape1 = out_shape
+    out_shape2 = out_shape
+    out_shape3 = out_shape
+    out_shape4 = out_shape
+    if expand==True:
+        out_shape1 = out_shape
+        out_shape2 = out_shape*2
+        out_shape3 = out_shape*4
+        out_shape4 = out_shape*8
+
+    scratch.layer1_rn = nn.Conv2d(
+        in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+    scratch.layer2_rn = nn.Conv2d(
+        in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+    scratch.layer3_rn = nn.Conv2d(
+        in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+    scratch.layer4_rn = nn.Conv2d(
+        in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+    )
+
+    return scratch
+
+
+def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
+    efficientnet = torch.hub.load(
+        "rwightman/gen-efficientnet-pytorch",
+        "tf_efficientnet_lite3",
+        pretrained=use_pretrained,
+        exportable=exportable
+    )
+    return _make_efficientnet_backbone(efficientnet)
+
+
+def _make_efficientnet_backbone(effnet):
+    pretrained = nn.Module()
+
+    pretrained.layer1 = nn.Sequential(
+        effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
+    )
+    pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
+    pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
+    pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
+
+    return pretrained
+    
+
+def _make_resnet_backbone(resnet):
+    pretrained = nn.Module()
+    pretrained.layer1 = nn.Sequential(
+        resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
+    )
+
+    pretrained.layer2 = resnet.layer2
+    pretrained.layer3 = resnet.layer3
+    pretrained.layer4 = resnet.layer4
+
+    return pretrained
+
+
+def _make_pretrained_resnext101_wsl(use_pretrained):
+    resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
+    return _make_resnet_backbone(resnet)
+
+
+
+class Interpolate(nn.Module):
+    """Interpolation module.
+    """
+
+    def __init__(self, scale_factor, mode, align_corners=False):
+        """Init.
+
+        Args:
+            scale_factor (float): scaling
+            mode (str): interpolation mode
+        """
+        super(Interpolate, self).__init__()
+
+        self.interp = nn.functional.interpolate
+        self.scale_factor = scale_factor
+        self.mode = mode
+        self.align_corners = align_corners
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input
+
+        Returns:
+            tensor: interpolated data
+        """
+
+        x = self.interp(
+            x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
+        )
+
+        return x
+
+
+class ResidualConvUnit(nn.Module):
+    """Residual convolution module.
+    """
+
+    def __init__(self, features):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super().__init__()
+
+        self.conv1 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True
+        )
+
+        self.conv2 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True
+        )
+
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input
+
+        Returns:
+            tensor: output
+        """
+        out = self.relu(x)
+        out = self.conv1(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+
+        return out + x
+
+
+class FeatureFusionBlock(nn.Module):
+    """Feature fusion block.
+    """
+
+    def __init__(self, features):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super(FeatureFusionBlock, self).__init__()
+
+        self.resConfUnit1 = ResidualConvUnit(features)
+        self.resConfUnit2 = ResidualConvUnit(features)
+
+    def forward(self, *xs):
+        """Forward pass.
+
+        Returns:
+            tensor: output
+        """
+        output = xs[0]
+
+        if len(xs) == 2:
+            output += self.resConfUnit1(xs[1])
+
+        output = self.resConfUnit2(output)
+
+        output = nn.functional.interpolate(
+            output, scale_factor=2, mode="bilinear", align_corners=True
+        )
+
+        return output
+
+
+
+
+class ResidualConvUnit_custom(nn.Module):
+    """Residual convolution module.
+    """
+
+    def __init__(self, features, activation, bn):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super().__init__()
+
+        self.bn = bn
+
+        self.groups=1
+
+        self.conv1 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+        )
+        
+        self.conv2 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+        )
+
+        if self.bn==True:
+            self.bn1 = nn.BatchNorm2d(features)
+            self.bn2 = nn.BatchNorm2d(features)
+
+        self.activation = activation
+
+        self.skip_add = nn.quantized.FloatFunctional()
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input
+
+        Returns:
+            tensor: output
+        """
+        
+        out = self.activation(x)
+        out = self.conv1(out)
+        if self.bn==True:
+            out = self.bn1(out)
+       
+        out = self.activation(out)
+        out = self.conv2(out)
+        if self.bn==True:
+            out = self.bn2(out)
+
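+        # self.groups is hard-coded to 1 in __init__, so this branch (which would
+        # need a conv_merge layer that is never created) is effectively dead code.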
+        if self.groups > 1:
+            out = self.conv_merge(out)
+
+        return self.skip_add.add(out, x)
+
+        # return out + x
+
+
+class FeatureFusionBlock_custom(nn.Module):
+    """Feature fusion block.
+    """
+
+    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
+        """Init.
+
+        Args:
+            features (int): number of features
+        """
+        super(FeatureFusionBlock_custom, self).__init__()
+
+        self.deconv = deconv
+        self.align_corners = align_corners
+
+        self.groups=1
+
+        self.expand = expand
+        out_features = features
+        if self.expand==True:
+            out_features = features//2
+        
+        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
+        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
+        
+        self.skip_add = nn.quantized.FloatFunctional()
+
+    def forward(self, *xs):
+        """Forward pass.
+
+        Returns:
+            tensor: output
+        """
+        output = xs[0]
+
+        if len(xs) == 2:
+            res = self.resConfUnit1(xs[1])
+            output = self.skip_add.add(output, res)
+            # output += res
+
+        output = self.resConfUnit2(output)
+
+        output = nn.functional.interpolate(
+            output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
+        )
+
+        output = self.out_conv(output)
+
+        return output
+
diff --git a/ldm/modules/midas/midas/dpt_depth.py b/ldm/modules/midas/midas/dpt_depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1
--- /dev/null
+++ b/ldm/modules/midas/midas/dpt_depth.py
@@ -0,0 +1,109 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .base_model import BaseModel
+from .blocks import (
+    FeatureFusionBlock,
+    FeatureFusionBlock_custom,
+    Interpolate,
+    _make_encoder,
+    forward_vit,
+)
+
+
+def _make_fusion_block(features, use_bn):
+    return FeatureFusionBlock_custom(
+        features,
+        nn.ReLU(False),
+        deconv=False,
+        bn=use_bn,
+        expand=False,
+        align_corners=True,
+    )
+
+
+class DPT(BaseModel):
+    def __init__(
+        self,
+        head,
+        features=256,
+        backbone="vitb_rn50_384",
+        readout="project",
+        channels_last=False,
+        use_bn=False,
+    ):
+
+        super(DPT, self).__init__()
+
+        self.channels_last = channels_last
+
+        hooks = {
+            "vitb_rn50_384": [0, 1, 8, 11],
+            "vitb16_384": [2, 5, 8, 11],
+            "vitl16_384": [5, 11, 17, 23],
+        }
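+        # hooks[backbone] picks which transformer blocks feed the four decoder levels.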
+
+        # Instantiate backbone and reassemble blocks
+        self.pretrained, self.scratch = _make_encoder(
+            backbone,
+            features,
+            False,  # use_pretrained: set True to initialize the backbone from ImageNet weights (e.g. when training from scratch)
+            groups=1,
+            expand=False,
+            exportable=False,
+            hooks=hooks[backbone],
+            use_readout=readout,
+        )
+
+        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+        self.scratch.output_conv = head
+
+
+    def forward(self, x):
+        if self.channels_last:
+            # contiguous() is not in-place, so keep the returned tensor.
+            x = x.contiguous(memory_format=torch.channels_last)
+
+        layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
+
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+        path_4 = self.scratch.refinenet4(layer_4_rn)
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+        out = self.scratch.output_conv(path_1)
+
+        return out
+
+
+class DPTDepthModel(DPT):
+    def __init__(self, path=None, non_negative=True, **kwargs):
+        features = kwargs["features"] if "features" in kwargs else 256
+
+        head = nn.Sequential(
+            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(True),
+            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+            nn.ReLU(True) if non_negative else nn.Identity(),
+            nn.Identity(),
+        )
+
+        super().__init__(head, **kwargs)
+
+        if path is not None:
+           self.load(path)
+
+    def forward(self, x):
+        return super().forward(x).squeeze(dim=1)
+
diff --git a/ldm/modules/midas/midas/midas_net.py b/ldm/modules/midas/midas/midas_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f
--- /dev/null
+++ b/ldm/modules/midas/midas/midas_net.py
@@ -0,0 +1,76 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
+
+
+class MidasNet(BaseModel):
+    """Network for monocular depth estimation.
+    """
+
+    def __init__(self, path=None, features=256, non_negative=True):
+        """Init.
+
+        Args:
+            path (str, optional): Path to saved model. Defaults to None.
+            features (int, optional): Number of features. Defaults to 256.
+            non_negative (bool, optional): If True, a final ReLU keeps the output non-negative. Defaults to True.
+
+        The encoder backbone is fixed to "resnext101_wsl".
+        """
+        print("Loading weights: ", path)
+
+        super(MidasNet, self).__init__()
+
+        use_pretrained = False if path is None else True
+
+        self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
+
+        self.scratch.refinenet4 = FeatureFusionBlock(features)
+        self.scratch.refinenet3 = FeatureFusionBlock(features)
+        self.scratch.refinenet2 = FeatureFusionBlock(features)
+        self.scratch.refinenet1 = FeatureFusionBlock(features)
+
+        self.scratch.output_conv = nn.Sequential(
+            nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
+            Interpolate(scale_factor=2, mode="bilinear"),
+            nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(True),
+            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+            nn.ReLU(True) if non_negative else nn.Identity(),
+        )
+
+        if path:
+            self.load(path)
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input data (image)
+
+        Returns:
+            tensor: depth
+        """
+
+        layer_1 = self.pretrained.layer1(x)
+        layer_2 = self.pretrained.layer2(layer_1)
+        layer_3 = self.pretrained.layer3(layer_2)
+        layer_4 = self.pretrained.layer4(layer_3)
+
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+        path_4 = self.scratch.refinenet4(layer_4_rn)
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+        out = self.scratch.output_conv(path_1)
+
+        return torch.squeeze(out, dim=1)
diff --git a/ldm/modules/midas/midas/midas_net_custom.py b/ldm/modules/midas/midas/midas_net_custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c
--- /dev/null
+++ b/ldm/modules/midas/midas/midas_net_custom.py
@@ -0,0 +1,128 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
+
+
+class MidasNet_small(BaseModel):
+    """Network for monocular depth estimation.
+    """
+
+    def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
+        blocks={'expand': True}):
+        """Init.
+
+        Args:
+            path (str, optional): Path to saved model. Defaults to None.
+            features (int, optional): Number of features. Defaults to 64.
+            backbone (str, optional): Backbone network for encoder. Defaults to "efficientnet_lite3".
+        """
+        print("Loading weights: ", path)
+
+        super(MidasNet_small, self).__init__()
+
+        use_pretrained = False if path else True
+                
+        self.channels_last = channels_last
+        self.blocks = blocks
+        self.backbone = backbone
+
+        self.groups = 1
+
+        features1=features
+        features2=features
+        features3=features
+        features4=features
+        self.expand = False
+        if "expand" in self.blocks and self.blocks['expand'] == True:
+            self.expand = True
+            features1=features
+            features2=features*2
+            features3=features*4
+            features4=features*8
+
+        self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
+  
+        self.scratch.activation = nn.ReLU(False)    
+
+        self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+        self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
+
+        
+        self.scratch.output_conv = nn.Sequential(
+            nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
+            Interpolate(scale_factor=2, mode="bilinear"),
+            nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
+            self.scratch.activation,
+            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+            nn.ReLU(True) if non_negative else nn.Identity(),
+            nn.Identity(),
+        )
+        
+        if path:
+            self.load(path)
+
+
+    def forward(self, x):
+        """Forward pass.
+
+        Args:
+            x (tensor): input data (image)
+
+        Returns:
+            tensor: depth
+        """
+        if self.channels_last:
+            print("self.channels_last = ", self.channels_last)
+            # contiguous() is not in-place, so keep the returned tensor.
+            x = x.contiguous(memory_format=torch.channels_last)
+
+
+        layer_1 = self.pretrained.layer1(x)
+        layer_2 = self.pretrained.layer2(layer_1)
+        layer_3 = self.pretrained.layer3(layer_2)
+        layer_4 = self.pretrained.layer4(layer_3)
+        
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+
+        path_4 = self.scratch.refinenet4(layer_4_rn)
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+        
+        out = self.scratch.output_conv(path_1)
+
+        return torch.squeeze(out, dim=1)
+
+
+
+def fuse_model(m):
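+    # Fuses consecutive Conv2d + BatchNorm2d (+ ReLU) modules in place, as expected
+    # by PyTorch post-training quantization.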
+    prev_previous_type = nn.Identity()
+    prev_previous_name = ''
+    previous_type = nn.Identity()
+    previous_name = ''
+    for name, module in m.named_modules():
+        if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
+            # print("FUSED ", prev_previous_name, previous_name, name)
+            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
+        elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
+            # print("FUSED ", prev_previous_name, previous_name)
+            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
+        # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
+        #    print("FUSED ", previous_name, name)
+        #    torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
+
+        prev_previous_type = previous_type
+        prev_previous_name = previous_name
+        previous_type = type(module)
+        previous_name = name
\ No newline at end of file
diff --git a/ldm/modules/midas/midas/transforms.py b/ldm/modules/midas/midas/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9
--- /dev/null
+++ b/ldm/modules/midas/midas/transforms.py
@@ -0,0 +1,234 @@
+import numpy as np
+import cv2
+import math
+
+
+def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
+    """Rezise the sample to ensure the given size. Keeps aspect ratio.
+
+    Args:
+        sample (dict): sample
+        size (tuple): image size
+
+    Returns:
+        tuple: new size
+    """
+    shape = list(sample["disparity"].shape)
+
+    if shape[0] >= size[0] and shape[1] >= size[1]:
+        return sample
+
+    scale = [0, 0]
+    scale[0] = size[0] / shape[0]
+    scale[1] = size[1] / shape[1]
+
+    scale = max(scale)
+
+    shape[0] = math.ceil(scale * shape[0])
+    shape[1] = math.ceil(scale * shape[1])
+
+    # resize
+    sample["image"] = cv2.resize(
+        sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
+    )
+
+    sample["disparity"] = cv2.resize(
+        sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
+    )
+    sample["mask"] = cv2.resize(
+        sample["mask"].astype(np.float32),
+        tuple(shape[::-1]),
+        interpolation=cv2.INTER_NEAREST,
+    )
+    sample["mask"] = sample["mask"].astype(bool)
+
+    return tuple(shape)
+
+
+class Resize(object):
+    """Resize sample to given size (width, height).
+    """
+
+    def __init__(
+        self,
+        width,
+        height,
+        resize_target=True,
+        keep_aspect_ratio=False,
+        ensure_multiple_of=1,
+        resize_method="lower_bound",
+        image_interpolation_method=cv2.INTER_AREA,
+    ):
+        """Init.
+
+        Args:
+            width (int): desired output width
+            height (int): desired output height
+            resize_target (bool, optional):
+                True: Resize the full sample (image, mask, target).
+                False: Resize image only.
+                Defaults to True.
+            keep_aspect_ratio (bool, optional):
+                True: Keep the aspect ratio of the input sample.
+                Output sample might not have the given width and height, and
+                resize behaviour depends on the parameter 'resize_method'.
+                Defaults to False.
+            ensure_multiple_of (int, optional):
+                Output width and height is constrained to be multiple of this parameter.
+                Defaults to 1.
+            resize_method (str, optional):
+                "lower_bound": Output will be at least as large as the given size.
+                "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
+                "minimal": Scale as least as possible.  (Output size might be smaller than given size.)
+                Defaults to "lower_bound".
+        """
+        self.__width = width
+        self.__height = height
+
+        self.__resize_target = resize_target
+        self.__keep_aspect_ratio = keep_aspect_ratio
+        self.__multiple_of = ensure_multiple_of
+        self.__resize_method = resize_method
+        self.__image_interpolation_method = image_interpolation_method
+
+    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
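+        # Round x to the nearest multiple of __multiple_of, falling back to floor/ceil
+        # so the result stays within [min_val, max_val].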
+        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+        if max_val is not None and y > max_val:
+            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+        if y < min_val:
+            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+        return y
+
+    def get_size(self, width, height):
+        # determine new height and width
+        scale_height = self.__height / height
+        scale_width = self.__width / width
+
+        if self.__keep_aspect_ratio:
+            if self.__resize_method == "lower_bound":
+                # scale such that output size is lower bound
+                if scale_width > scale_height:
+                    # fit width
+                    scale_height = scale_width
+                else:
+                    # fit height
+                    scale_width = scale_height
+            elif self.__resize_method == "upper_bound":
+                # scale such that output size is upper bound
+                if scale_width < scale_height:
+                    # fit width
+                    scale_height = scale_width
+                else:
+                    # fit height
+                    scale_width = scale_height
+            elif self.__resize_method == "minimal":
+                # scale as little as possible
+                if abs(1 - scale_width) < abs(1 - scale_height):
+                    # fit width
+                    scale_height = scale_width
+                else:
+                    # fit height
+                    scale_width = scale_height
+            else:
+                raise ValueError(
+                    f"resize_method {self.__resize_method} not implemented"
+                )
+
+        if self.__resize_method == "lower_bound":
+            new_height = self.constrain_to_multiple_of(
+                scale_height * height, min_val=self.__height
+            )
+            new_width = self.constrain_to_multiple_of(
+                scale_width * width, min_val=self.__width
+            )
+        elif self.__resize_method == "upper_bound":
+            new_height = self.constrain_to_multiple_of(
+                scale_height * height, max_val=self.__height
+            )
+            new_width = self.constrain_to_multiple_of(
+                scale_width * width, max_val=self.__width
+            )
+        elif self.__resize_method == "minimal":
+            new_height = self.constrain_to_multiple_of(scale_height * height)
+            new_width = self.constrain_to_multiple_of(scale_width * width)
+        else:
+            raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+        return (new_width, new_height)
+
+    def __call__(self, sample):
+        width, height = self.get_size(
+            sample["image"].shape[1], sample["image"].shape[0]
+        )
+
+        # resize sample
+        sample["image"] = cv2.resize(
+            sample["image"],
+            (width, height),
+            interpolation=self.__image_interpolation_method,
+        )
+
+        if self.__resize_target:
+            if "disparity" in sample:
+                sample["disparity"] = cv2.resize(
+                    sample["disparity"],
+                    (width, height),
+                    interpolation=cv2.INTER_NEAREST,
+                )
+
+            if "depth" in sample:
+                sample["depth"] = cv2.resize(
+                    sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
+                )
+
+            sample["mask"] = cv2.resize(
+                sample["mask"].astype(np.float32),
+                (width, height),
+                interpolation=cv2.INTER_NEAREST,
+            )
+            sample["mask"] = sample["mask"].astype(bool)
+
+        return sample
+
+
+class NormalizeImage(object):
+    """Normlize image by given mean and std.
+    """
+
+    def __init__(self, mean, std):
+        self.__mean = mean
+        self.__std = std
+
+    def __call__(self, sample):
+        sample["image"] = (sample["image"] - self.__mean) / self.__std
+
+        return sample
+
+
+class PrepareForNet(object):
+    """Prepare sample for usage as network input.
+    """
+
+    def __init__(self):
+        pass
+
+    def __call__(self, sample):
+        image = np.transpose(sample["image"], (2, 0, 1))
+        sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+
+        if "mask" in sample:
+            sample["mask"] = sample["mask"].astype(np.float32)
+            sample["mask"] = np.ascontiguousarray(sample["mask"])
+
+        if "disparity" in sample:
+            disparity = sample["disparity"].astype(np.float32)
+            sample["disparity"] = np.ascontiguousarray(disparity)
+
+        if "depth" in sample:
+            depth = sample["depth"].astype(np.float32)
+            sample["depth"] = np.ascontiguousarray(depth)
+
+        return sample
diff --git a/ldm/modules/midas/midas/vit.py b/ldm/modules/midas/midas/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74
--- /dev/null
+++ b/ldm/modules/midas/midas/vit.py
@@ -0,0 +1,491 @@
+import torch
+import torch.nn as nn
+import timm
+import types
+import math
+import torch.nn.functional as F
+
+
+class Slice(nn.Module):
+    def __init__(self, start_index=1):
+        super(Slice, self).__init__()
+        self.start_index = start_index
+
+    def forward(self, x):
+        return x[:, self.start_index :]
+
+
+class AddReadout(nn.Module):
+    def __init__(self, start_index=1):
+        super(AddReadout, self).__init__()
+        self.start_index = start_index
+
+    def forward(self, x):
+        if self.start_index == 2:
+            readout = (x[:, 0] + x[:, 1]) / 2
+        else:
+            readout = x[:, 0]
+        return x[:, self.start_index :] + readout.unsqueeze(1)
+
+
+class ProjectReadout(nn.Module):
+    def __init__(self, in_features, start_index=1):
+        super(ProjectReadout, self).__init__()
+        self.start_index = start_index
+
+        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
+
+    def forward(self, x):
+        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
+        features = torch.cat((x[:, self.start_index :], readout), -1)
+
+        return self.project(features)
+
+
+class Transpose(nn.Module):
+    def __init__(self, dim0, dim1):
+        super(Transpose, self).__init__()
+        self.dim0 = dim0
+        self.dim1 = dim1
+
+    def forward(self, x):
+        x = x.transpose(self.dim0, self.dim1)
+        return x
+
+
+def forward_vit(pretrained, x):
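+    # Runs the patched ViT forward pass and reshapes the four hooked token sequences
+    # back into 2D feature maps via the act_postprocess heads.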
+    b, c, h, w = x.shape
+
+    glob = pretrained.model.forward_flex(x)
+
+    layer_1 = pretrained.activations["1"]
+    layer_2 = pretrained.activations["2"]
+    layer_3 = pretrained.activations["3"]
+    layer_4 = pretrained.activations["4"]
+
+    layer_1 = pretrained.act_postprocess1[0:2](layer_1)
+    layer_2 = pretrained.act_postprocess2[0:2](layer_2)
+    layer_3 = pretrained.act_postprocess3[0:2](layer_3)
+    layer_4 = pretrained.act_postprocess4[0:2](layer_4)
+
+    unflatten = nn.Sequential(
+        nn.Unflatten(
+            2,
+            torch.Size(
+                [
+                    h // pretrained.model.patch_size[1],
+                    w // pretrained.model.patch_size[0],
+                ]
+            ),
+        )
+    )
+
+    if layer_1.ndim == 3:
+        layer_1 = unflatten(layer_1)
+    if layer_2.ndim == 3:
+        layer_2 = unflatten(layer_2)
+    if layer_3.ndim == 3:
+        layer_3 = unflatten(layer_3)
+    if layer_4.ndim == 3:
+        layer_4 = unflatten(layer_4)
+
+    layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
+    layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
+    layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
+    layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
+
+    return layer_1, layer_2, layer_3, layer_4
+
+
+def _resize_pos_embed(self, posemb, gs_h, gs_w):
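+    # Bilinearly interpolates the learned positional-embedding grid to the current
+    # patch grid (gs_h x gs_w) so inputs need not match the pretraining resolution.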
+    posemb_tok, posemb_grid = (
+        posemb[:, : self.start_index],
+        posemb[0, self.start_index :],
+    )
+
+    gs_old = int(math.sqrt(len(posemb_grid)))
+
+    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
+    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
+
+    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+
+    return posemb
+
+
+def forward_flex(self, x):
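+    # ViT forward pass that resizes the positional embedding on the fly, so input
+    # sizes only need to be multiples of the patch size.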
+    b, c, h, w = x.shape
+
+    pos_embed = self._resize_pos_embed(
+        self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
+    )
+
+    B = x.shape[0]
+
+    if hasattr(self.patch_embed, "backbone"):
+        x = self.patch_embed.backbone(x)
+        if isinstance(x, (list, tuple)):
+            x = x[-1]  # last feature if backbone outputs list/tuple of features
+
+    x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
+
+    if getattr(self, "dist_token", None) is not None:
+        cls_tokens = self.cls_token.expand(
+            B, -1, -1
+        )  # stole cls_tokens impl from Phil Wang, thanks
+        dist_token = self.dist_token.expand(B, -1, -1)
+        x = torch.cat((cls_tokens, dist_token, x), dim=1)
+    else:
+        cls_tokens = self.cls_token.expand(
+            B, -1, -1
+        )  # stole cls_tokens impl from Phil Wang, thanks
+        x = torch.cat((cls_tokens, x), dim=1)
+
+    x = x + pos_embed
+    x = self.pos_drop(x)
+
+    for blk in self.blocks:
+        x = blk(x)
+
+    x = self.norm(x)
+
+    return x
+
+
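+# Module-level store filled by the forward hooks below; keys "1".."4" hold the
+# activations of the hooked transformer blocks for the most recent forward pass.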
+activations = {}
+
+
+def get_activation(name):
+    def hook(model, input, output):
+        activations[name] = output
+
+    return hook
+
+
+def get_readout_oper(vit_features, features, use_readout, start_index=1):
+    if use_readout == "ignore":
+        readout_oper = [Slice(start_index)] * len(features)
+    elif use_readout == "add":
+        readout_oper = [AddReadout(start_index)] * len(features)
+    elif use_readout == "project":
+        readout_oper = [
+            ProjectReadout(vit_features, start_index) for out_feat in features
+        ]
+    else:
+        assert (
+            False
+        ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
+
+    return readout_oper
+
+
+def _make_vit_b16_backbone(
+    model,
+    features=[96, 192, 384, 768],
+    size=[384, 384],
+    hooks=[2, 5, 8, 11],
+    vit_features=768,
+    use_readout="ignore",
+    start_index=1,
+):
+    pretrained = nn.Module()
+
+    pretrained.model = model
+    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+    pretrained.activations = activations
+
+    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+    # 32, 48, 136, 384
+    pretrained.act_postprocess1 = nn.Sequential(
+        readout_oper[0],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[0],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+        nn.ConvTranspose2d(
+            in_channels=features[0],
+            out_channels=features[0],
+            kernel_size=4,
+            stride=4,
+            padding=0,
+            bias=True,
+            dilation=1,
+            groups=1,
+        ),
+    )
+
+    pretrained.act_postprocess2 = nn.Sequential(
+        readout_oper[1],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[1],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+        nn.ConvTranspose2d(
+            in_channels=features[1],
+            out_channels=features[1],
+            kernel_size=2,
+            stride=2,
+            padding=0,
+            bias=True,
+            dilation=1,
+            groups=1,
+        ),
+    )
+
+    pretrained.act_postprocess3 = nn.Sequential(
+        readout_oper[2],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[2],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+    )
+
+    pretrained.act_postprocess4 = nn.Sequential(
+        readout_oper[3],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[3],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+        nn.Conv2d(
+            in_channels=features[3],
+            out_channels=features[3],
+            kernel_size=3,
+            stride=2,
+            padding=1,
+        ),
+    )
+
+    pretrained.model.start_index = start_index
+    pretrained.model.patch_size = [16, 16]
+
+    # We inject this function into the VisionTransformer instances so that
+    # we can use it with interpolated position embeddings without modifying the library source.
+    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+    pretrained.model._resize_pos_embed = types.MethodType(
+        _resize_pos_embed, pretrained.model
+    )
+
+    return pretrained
+
+
+def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
+    model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
+
+    hooks = [5, 11, 17, 23] if hooks is None else hooks
+    return _make_vit_b16_backbone(
+        model,
+        features=[256, 512, 1024, 1024],
+        hooks=hooks,
+        vit_features=1024,
+        use_readout=use_readout,
+    )
+
+
+def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
+    model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
+
+    hooks = [2, 5, 8, 11] if hooks is None else hooks
+    return _make_vit_b16_backbone(
+        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+    )
+
+
+def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
+    model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
+
+    hooks = [2, 5, 8, 11] if hooks is None else hooks
+    return _make_vit_b16_backbone(
+        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+    )
+
+
+def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
+    model = timm.create_model(
+        "vit_deit_base_distilled_patch16_384", pretrained=pretrained
+    )
+
+    hooks = [2, 5, 8, 11] if hooks is None else hooks
+    return _make_vit_b16_backbone(
+        model,
+        features=[96, 192, 384, 768],
+        hooks=hooks,
+        use_readout=use_readout,
+        start_index=2,
+    )
+
+
+def _make_vit_b_rn50_backbone(
+    model,
+    features=[256, 512, 768, 768],
+    size=[384, 384],
+    hooks=[0, 1, 8, 11],
+    vit_features=768,
+    use_vit_only=False,
+    use_readout="ignore",
+    start_index=1,
+):
+    pretrained = nn.Module()
+
+    pretrained.model = model
+
+    if use_vit_only == True:
+        pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+        pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+    else:
+        pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
+            get_activation("1")
+        )
+        pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
+            get_activation("2")
+        )
+
+    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+    pretrained.activations = activations
+
+    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+    if use_vit_only == True:
+        pretrained.act_postprocess1 = nn.Sequential(
+            readout_oper[0],
+            Transpose(1, 2),
+            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+            nn.Conv2d(
+                in_channels=vit_features,
+                out_channels=features[0],
+                kernel_size=1,
+                stride=1,
+                padding=0,
+            ),
+            nn.ConvTranspose2d(
+                in_channels=features[0],
+                out_channels=features[0],
+                kernel_size=4,
+                stride=4,
+                padding=0,
+                bias=True,
+                dilation=1,
+                groups=1,
+            ),
+        )
+
+        pretrained.act_postprocess2 = nn.Sequential(
+            readout_oper[1],
+            Transpose(1, 2),
+            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+            nn.Conv2d(
+                in_channels=vit_features,
+                out_channels=features[1],
+                kernel_size=1,
+                stride=1,
+                padding=0,
+            ),
+            nn.ConvTranspose2d(
+                in_channels=features[1],
+                out_channels=features[1],
+                kernel_size=2,
+                stride=2,
+                padding=0,
+                bias=True,
+                dilation=1,
+                groups=1,
+            ),
+        )
+    else:
+        pretrained.act_postprocess1 = nn.Sequential(
+            nn.Identity(), nn.Identity(), nn.Identity()
+        )
+        pretrained.act_postprocess2 = nn.Sequential(
+            nn.Identity(), nn.Identity(), nn.Identity()
+        )
+
+    pretrained.act_postprocess3 = nn.Sequential(
+        readout_oper[2],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[2],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+    )
+
+    pretrained.act_postprocess4 = nn.Sequential(
+        readout_oper[3],
+        Transpose(1, 2),
+        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+        nn.Conv2d(
+            in_channels=vit_features,
+            out_channels=features[3],
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        ),
+        nn.Conv2d(
+            in_channels=features[3],
+            out_channels=features[3],
+            kernel_size=3,
+            stride=2,
+            padding=1,
+        ),
+    )
+
+    pretrained.model.start_index = start_index
+    pretrained.model.patch_size = [16, 16]
+
+    # We inject this function into the VisionTransformer instances so that
+    # we can use it with interpolated position embeddings without modifying the library source.
+    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+
+    # We inject this function into the VisionTransformer instances so that
+    # we can use it with interpolated position embeddings without modifying the library source.
+    pretrained.model._resize_pos_embed = types.MethodType(
+        _resize_pos_embed, pretrained.model
+    )
+
+    return pretrained
+
+
+def _make_pretrained_vitb_rn50_384(
+    pretrained, use_readout="ignore", hooks=None, use_vit_only=False
+):
+    model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
+
+    hooks = [0, 1, 8, 11] if hooks is None else hooks
+    return _make_vit_b_rn50_backbone(
+        model,
+        features=[256, 512, 768, 768],
+        size=[384, 384],
+        hooks=hooks,
+        use_vit_only=use_vit_only,
+        use_readout=use_readout,
+    )
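+
+
+# Hedged usage sketch (not part of the original MiDaS file): build the hybrid
+# R50+ViT-B/16 backbone with random weights and run one flexible forward pass.
+# It assumes an installed timm version in which the "vit_base_resnet50_384"
+# model name used above still resolves.
+def _example_vitb_rn50_backbone():
+    backbone = _make_pretrained_vitb_rn50_384(pretrained=False)
+    x = torch.randn(1, 3, 384, 384)
+    _ = backbone.model.forward_flex(x)  # forward hooks fill backbone.activations
+    return {k: v.shape for k, v in backbone.activations.items()}  # keys "1".."4"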
diff --git a/ldm/modules/midas/utils.py b/ldm/modules/midas/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59
--- /dev/null
+++ b/ldm/modules/midas/utils.py
@@ -0,0 +1,189 @@
+"""Utils for monoDepth."""
+import sys
+import re
+import numpy as np
+import cv2
+import torch
+
+
+def read_pfm(path):
+    """Read pfm file.
+
+    Args:
+        path (str): path to file
+
+    Returns:
+        tuple: (data, scale)
+    """
+    with open(path, "rb") as file:
+
+        color = None
+        width = None
+        height = None
+        scale = None
+        endian = None
+
+        header = file.readline().rstrip()
+        if header.decode("ascii") == "PF":
+            color = True
+        elif header.decode("ascii") == "Pf":
+            color = False
+        else:
+            raise Exception("Not a PFM file: " + path)
+
+        dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
+        if dim_match:
+            width, height = list(map(int, dim_match.groups()))
+        else:
+            raise Exception("Malformed PFM header.")
+
+        scale = float(file.readline().decode("ascii").rstrip())
+        if scale < 0:
+            # little-endian
+            endian = "<"
+            scale = -scale
+        else:
+            # big-endian
+            endian = ">"
+
+        data = np.fromfile(file, endian + "f")
+        shape = (height, width, 3) if color else (height, width)
+
+        data = np.reshape(data, shape)
+        data = np.flipud(data)
+
+        return data, scale
+
+
+def write_pfm(path, image, scale=1):
+    """Write pfm file.
+
+    Args:
+        path (str): path to file
+        image (array): data
+        scale (int, optional): Scale. Defaults to 1.
+    """
+
+    with open(path, "wb") as file:
+        color = None
+
+        if image.dtype.name != "float32":
+            raise Exception("Image dtype must be float32.")
+
+        image = np.flipud(image)
+
+        if len(image.shape) == 3 and image.shape[2] == 3:  # color image
+            color = True
+        elif (
+            len(image.shape) == 2 or (len(image.shape) == 3 and image.shape[2] == 1)
+        ):  # greyscale
+            color = False
+        else:
+            raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
+
+        file.write("PF\n" if color else "Pf\n".encode())
+        file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
+
+        endian = image.dtype.byteorder
+
+        if endian == "<" or endian == "=" and sys.byteorder == "little":
+            scale = -scale
+
+        file.write("%f\n".encode() % scale)
+
+        image.tofile(file)
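+
+
+# Hedged round-trip sketch (not part of the original MiDaS utils): write a small
+# greyscale float32 map to PFM and read it back. The file name is illustrative.
+def _example_pfm_roundtrip(path="/tmp/example.pfm"):
+    depth = np.random.rand(4, 5).astype(np.float32)
+    write_pfm(path, depth)
+    data, scale = read_pfm(path)
+    assert data.shape == depth.shape and np.allclose(data, depth)
+    return scale  # 1.0 for the default scale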
+
+
+def read_image(path):
+    """Read image and output RGB image (0-1).
+
+    Args:
+        path (str): path to file
+
+    Returns:
+        array: RGB image (0-1)
+    """
+    img = cv2.imread(path)
+
+    if img.ndim == 2:
+        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
+
+    return img
+
+
+def resize_image(img):
+    """Resize image and make it fit for network.
+
+    Args:
+        img (array): image
+
+    Returns:
+        tensor: data ready for network
+    """
+    height_orig = img.shape[0]
+    width_orig = img.shape[1]
+
+    if width_orig > height_orig:
+        scale = width_orig / 384
+    else:
+        scale = height_orig / 384
+
+    height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
+    width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
+
+    img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
+
+    img_resized = (
+        torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
+    )
+    img_resized = img_resized.unsqueeze(0)
+
+    return img_resized
+
+
+def resize_depth(depth, width, height):
+    """Resize depth map and bring to CPU (numpy).
+
+    Args:
+        depth (tensor): depth
+        width (int): image width
+        height (int): image height
+
+    Returns:
+        array: processed depth
+    """
+    depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
+
+    depth_resized = cv2.resize(
+        depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
+    )
+
+    return depth_resized
+
+
+def write_depth(path, depth, bits=1):
+    """Write depth map to pfm and png file.
+
+    Args:
+        path (str): filepath without extension
+        depth (array): depth
+        bits (int, optional): bytes per PNG channel (1 -> uint8, 2 -> uint16). Defaults to 1.
+    """
+    write_pfm(path + ".pfm", depth.astype(np.float32))
+
+    depth_min = depth.min()
+    depth_max = depth.max()
+
+    max_val = (2 ** (8 * bits)) - 1
+
+    if depth_max - depth_min > np.finfo("float").eps:
+        out = max_val * (depth - depth_min) / (depth_max - depth_min)
+    else:
+        out = np.zeros(depth.shape, dtype=depth.dtype)
+
+    if bits == 1:
+        cv2.imwrite(path + ".png", out.astype("uint8"))
+    elif bits == 2:
+        cv2.imwrite(path + ".png", out.astype("uint16"))
+
+    return
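+
+
+# Hedged end-to-end sketch (not part of the original file): how these helpers are
+# typically chained around a depth network. "input.jpg" and the random prediction
+# are placeholders for a real image path and a real MiDaS model output.
+def _example_depth_io(image_path="input.jpg", out_prefix="depth_out"):
+    img = read_image(image_path)                  # H x W x 3 RGB in [0, 1]
+    sample = resize_image(img)                    # 1 x 3 x h x w network input
+    prediction = torch.rand(1, 1, sample.shape[2], sample.shape[3])
+    depth = resize_depth(prediction, img.shape[1], img.shape[0])
+    write_depth(out_prefix, depth, bits=2)        # writes .pfm and a 16-bit .png
+    return depth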
diff --git a/ldm/util.py b/ldm/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..45cb050ece6f401a22dde098ce3f1ff663c5eb6a
--- /dev/null
+++ b/ldm/util.py
@@ -0,0 +1,197 @@
+import importlib
+
+import torch
+from torch import optim
+import numpy as np
+
+from inspect import isfunction
+from PIL import Image, ImageDraw, ImageFont
+
+
+def log_txt_as_img(wh, xc, size=10):
+    # wh: a (width, height) tuple for the rendered images
+    # xc: a list of captions to plot, one per batch element
+    b = len(xc)
+    txts = list()
+    for bi in range(b):
+        txt = Image.new("RGB", wh, color="white")
+        draw = ImageDraw.Draw(txt)
+        font = ImageFont.truetype('font/DejaVuSans.ttf', size=size)
+        nc = int(40 * (wh[0] / 256))
+        lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
+
+        try:
+            draw.text((0, 0), lines, fill="black", font=font)
+        except UnicodeEncodeError:
+            print("Cant encode string for logging. Skipping.")
+
+        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+        txts.append(txt)
+    txts = np.stack(txts)
+    txts = torch.tensor(txts)
+    return txts
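+
+
+# Hedged sketch (not part of the original file): render two captions as
+# 3 x 256 x 256 tensors for image logging. It assumes font/DejaVuSans.ttf is
+# present relative to the working directory, as required by log_txt_as_img.
+def _example_log_txt_as_img():
+    grids = log_txt_as_img((256, 256), ["a photo of a cat", "a painting of a dog"])
+    return grids.shape  # (2, 3, 256, 256), values in [-1, 1]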
+
+
+def ismap(x):
+    if not isinstance(x, torch.Tensor):
+        return False
+    return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+    if not isinstance(x, torch.Tensor):
+        return False
+    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def exists(x):
+    return x is not None
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+    """
+    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+    Take the mean over all non-batch dimensions.
+    """
+    return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def count_params(model, verbose=False):
+    total_params = sum(p.numel() for p in model.parameters())
+    if verbose:
+        print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
+    return total_params
+
+
+def instantiate_from_config(config):
+    if not "target" in config:
+        if config == '__is_first_stage__':
+            return None
+        elif config == "__is_unconditional__":
+            return None
+        raise KeyError("Expected key `target` to instantiate.")
+    return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False):
+    module, cls = string.rsplit(".", 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
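+
+
+# Hedged sketch (not part of the original file): how a config dict drives
+# instantiation. The target/params values here are purely illustrative.
+def _example_instantiate_from_config():
+    cfg = {"target": "torch.nn.Linear", "params": {"in_features": 8, "out_features": 2}}
+    layer = instantiate_from_config(cfg)  # equivalent to torch.nn.Linear(8, 2)
+    return layer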
+
+
+class AdamWwithEMAandWings(optim.Optimizer):
+    # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
+    def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8,  # TODO: check hyperparameters before using
+                 weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999,   # ema decay to match previous code
+                 ema_power=1., param_names=()):
+        """AdamW that saves EMA versions of the parameters."""
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        if not 0.0 <= weight_decay:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        if not 0.0 <= ema_decay <= 1.0:
+            raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
+        defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
+                        ema_power=ema_power, param_names=param_names)
+        super().__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('amsgrad', False)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+        Args:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            exp_avgs = []
+            exp_avg_sqs = []
+            ema_params_with_grad = []
+            state_sums = []
+            max_exp_avg_sqs = []
+            state_steps = []
+            amsgrad = group['amsgrad']
+            beta1, beta2 = group['betas']
+            ema_decay = group['ema_decay']
+            ema_power = group['ema_power']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                params_with_grad.append(p)
+                if p.grad.is_sparse:
+                    raise RuntimeError('AdamW does not support sparse gradients')
+                grads.append(p.grad)
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    if amsgrad:
+                        # Maintains max of all exp. moving avg. of sq. grad. values
+                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    # Exponential moving average of parameter values
+                    state['param_exp_avg'] = p.detach().float().clone()
+
+                exp_avgs.append(state['exp_avg'])
+                exp_avg_sqs.append(state['exp_avg_sq'])
+                ema_params_with_grad.append(state['param_exp_avg'])
+
+                if amsgrad:
+                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])
+
+                # update the steps for each param group update
+                state['step'] += 1
+                # record the step after step update
+                state_steps.append(state['step'])
+
+            optim._functional.adamw(params_with_grad,
+                    grads,
+                    exp_avgs,
+                    exp_avg_sqs,
+                    max_exp_avg_sqs,
+                    state_steps,
+                    amsgrad=amsgrad,
+                    beta1=beta1,
+                    beta2=beta2,
+                    lr=group['lr'],
+                    weight_decay=group['weight_decay'],
+                    eps=group['eps'],
+                    maximize=False)
+
+            cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
+            for param, ema_param in zip(params_with_grad, ema_params_with_grad):
+                ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
+
+        return loss
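+
+
+# Hedged usage sketch (not part of the original file): one optimisation step with
+# the EMA-tracking AdamW variant above. It assumes a torch version whose private
+# optim._functional.adamw still accepts the integer state_steps passed by step().
+def _example_adamw_with_ema():
+    model = torch.nn.Linear(4, 1)
+    opt = AdamWwithEMAandWings(model.parameters(), lr=1e-3, ema_decay=0.999)
+    loss = model(torch.randn(16, 4)).pow(2).mean()
+    loss.backward()
+    opt.step()
+    opt.zero_grad()
+    # EMA copies of the parameters live in the per-parameter optimizer state.
+    return [opt.state[p]["param_exp_avg"] for p in model.parameters()]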
diff --git a/models/cldm_v15.yaml b/models/cldm_v15.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fde1825577acd46dc90d8d7c6730e22be762fccb
--- /dev/null
+++ b/models/cldm_v15.yaml
@@ -0,0 +1,79 @@
+model:
+  target: cldm.cldm.ControlLDM
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    control_key: "hint"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: False
+    only_mid_control: False
+
+    control_stage_config:
+      target: cldm.cldm.ControlNet
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        hint_channels: 3
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    unet_config:
+      target: cldm.cldm.ControlledUnetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
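+
+# Usage note (not part of the original config): in the ControlNet codebase this
+# file is typically consumed via cldm/model.py, roughly:
+#   from cldm.model import create_model, load_state_dict
+#   model = create_model('./models/cldm_v15.yaml').cpu()
+#   model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location='cpu'))
+# The checkpoint name above is only an example of a ControlNet weight file.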
diff --git a/models/cldm_v21.yaml b/models/cldm_v21.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fc65193647e476e108fce5977f11250d55919106
--- /dev/null
+++ b/models/cldm_v21.yaml
@@ -0,0 +1,85 @@
+model:
+  target: cldm.cldm.ControlLDM
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    control_key: "hint"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: False
+    only_mid_control: False
+
+    control_stage_config:
+      target: cldm.cldm.ControlNet
+      params:
+        use_checkpoint: True
+        image_size: 32 # unused
+        in_channels: 4
+        hint_channels: 3
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_head_channels: 64 # need to fix for flash-attn
+        use_spatial_transformer: True
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        legacy: False
+
+    unet_config:
+      target: cldm.cldm.ControlledUnetModel
+      params:
+        use_checkpoint: True
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_head_channels: 64 # need to fix for flash-attn
+        use_spatial_transformer: True
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          #attn_type: "vanilla-xformers"
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+      params:
+        freeze: True
+        layer: "penultimate"