|
|
|
|
|
from os import path as osp
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.nn.modules.utils import _pair
|
|
from torch.onnx.operators import shape_as_tensor
|
|
|
|
|
|
def bilinear_grid_sample(im, grid, align_corners=False):
    """Given an input and a flow-field grid, computes the output using input
    values and pixel locations from grid. Supported only bilinear interpolation
    method to sample the input pixels.

    Implemented purely with elementwise ops, ``F.pad`` and ``torch.gather``.
    Within this file it is used by ``point_sample`` only when exporting to
    ONNX without the custom-op library (presumably because ``F.grid_sample``
    cannot be exported in that setting -- confirm against the target opset).
    Neighbouring pixels that fall outside the input contribute zeros.

    Args:
        im (torch.Tensor): Input feature map, shape (N, C, H, W)
        grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2), in
            normalized [-1, 1] space; the last dimension is ordered (x, y).
        align_corners (bool): If set to True, the extrema (-1 and 1) are
            considered as referring to the center points of the input's
            corner pixels. If set to False, they are instead considered as
            referring to the corner points of the input's corner pixels,
            making the sampling more resolution agnostic.

    Returns:
        torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
    """
    n, c, h, w = im.shape
    gn, gh, gw, _ = grid.shape
    assert n == gn

    x = grid[:, :, :, 0]
    y = grid[:, :, :, 1]

    # Unnormalize [-1, 1] coordinates to pixel coordinates of the input.
    if align_corners:
        x = ((x + 1) / 2) * (w - 1)
        y = ((y + 1) / 2) * (h - 1)
    else:
        x = ((x + 1) * w - 1) / 2
        y = ((y + 1) * h - 1) / 2

    # Flatten the grid: (N, Hg, Wg) -> (N, Hg * Wg).
    x = x.view(n, -1)
    y = y.view(n, -1)

    # Integer corners of the 2x2 neighbourhood around each sample point.
    x0 = torch.floor(x).long()
    y0 = torch.floor(y).long()
    x1 = x0 + 1
    y1 = y0 + 1

    # Bilinear weights, computed from the unclamped corners; shape
    # (N, 1, Hg * Wg) so they broadcast over the channel dimension.
    wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
    wb = ((x1 - x) * (y - y0)).unsqueeze(1)
    wc = ((x - x0) * (y1 - y)).unsqueeze(1)
    wd = ((x - x0) * (y - y0)).unsqueeze(1)

    # Zero-pad one pixel on every side. After the shift and clamp below, any
    # out-of-range corner lands on this zero border, so it contributes 0.
    im_padded = F.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
    padded_h = h + 2
    padded_w = w + 2
    # Shift corner indices into the padded image's coordinate frame.
    x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1

    # Clamp corner indices to the padded bounds. torch.where is used instead
    # of clamp(), presumably for ONNX exportability -- confirm before changing.
    x0 = torch.where(x0 < 0, torch.tensor(0), x0)
    x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1), x0)
    x1 = torch.where(x1 < 0, torch.tensor(0), x1)
    x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1), x1)
    y0 = torch.where(y0 < 0, torch.tensor(0), y0)
    y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1), y0)
    y1 = torch.where(y1 < 0, torch.tensor(0), y1)
    y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1), y1)

    # Flatten the spatial dims so each corner is fetched with a single gather.
    im_padded = im_padded.view(n, c, -1)

    # Linear index (x + y * padded_w), repeated across channels for gather.
    x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)

    Ia = torch.gather(im_padded, 2, x0_y0)
    Ib = torch.gather(im_padded, 2, x0_y1)
    Ic = torch.gather(im_padded, 2, x1_y0)
    Id = torch.gather(im_padded, 2, x1_y1)

    # Weighted sum of the four corners, reshaped back to the grid layout.
    return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)
|
|
|
|
|
|
def is_in_onnx_export_without_custom_ops():
    """Report whether an ONNX export is running without the custom-op library.

    Returns:
        bool: True only while ``torch.onnx`` is exporting AND the path
        returned by ``get_onnxruntime_op_path()`` does not exist on disk.
    """
    from annotator.uniformer.mmcv.ops import get_onnxruntime_op_path
    custom_op_lib = get_onnxruntime_op_path()
    if not torch.onnx.is_in_onnx_export():
        return False
    return not osp.exists(custom_op_lib)
|
|
|
|
|
|
def normalize(grid):
    """Map grid coordinates from the [-1, 1] range onto [0, 1].

    Args:
        grid (Tensor): Coordinates in the range [-1, 1].

    Returns:
        Tensor: The same coordinates rescaled to the range [0, 1].
    """
    shifted = grid + 1.0
    return shifted / 2.0
|
|
|
|
|
|
def denormalize(grid):
    """Map grid coordinates from the [0, 1] range back onto [-1, 1].

    Args:
        grid (Tensor): Coordinates in the range [0, 1].

    Returns:
        Tensor: The same coordinates rescaled to the range [-1, 1].
    """
    doubled = grid * 2.0
    return doubled - 1.0
|
|
|
|
|
|
def generate_grid(num_grid, size, device):
    """Build a regular grid of point coordinates inside [0, 1] x [0, 1].

    The grid comes from an identity ``F.affine_grid`` (``align_corners=False``)
    mapped from [-1, 1] to [0, 1], and is shared (via ``expand``) by all
    ``num_grid`` entries.

    Args:
        num_grid (int): Number of grids to return, one per region.
        size (tuple(int, int)): Side lengths of the regular grid.
        device (torch.device): Device to create the grid on.

    Returns:
        (torch.Tensor): Tensor of shape (num_grid, size[0]*size[1], 2) holding
        the grid coordinates.
    """
    identity = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]], device=device)
    target_shape = torch.Size((1, 1, *size))
    unit_grid = normalize(
        F.affine_grid(identity, target_shape, align_corners=False))
    flat = unit_grid.view(1, -1, 2)
    return flat.expand(num_grid, -1, -1)
|
|
|
|
|
|
def rel_roi_point_to_abs_img_point(rois, rel_roi_points):
    """Convert roi based relative point coordinates to image based absolute
    point coordinates.

    Each point ``(u, v)`` inside RoI ``(x1, y1, x2, y2)`` maps to
    ``(x1 + u * (x2 - x1), y1 + v * (y2 - y1))``. The computation runs under
    ``torch.no_grad()``, so the result does not track gradients.

    Args:
        rois (Tensor): RoIs or BBoxes, shape (N, 4) or (N, 5). For (N, 5)
            the first column is dropped and the remaining four are used as
            (x1, y1, x2, y2).
        rel_roi_points (Tensor): Point coordinates inside RoI, relative to
            RoI, location, range (0, 1), shape (N, P, 2). Not modified.

    Returns:
        Tensor: Image based absolute point coordinates, shape (N, P, 2)
    """
    with torch.no_grad():
        assert rel_roi_points.size(0) == rois.size(0)
        assert rois.dim() == 2
        assert rel_roi_points.dim() == 3
        assert rel_roi_points.size(2) == 2

        if rois.size(1) == 5:
            rois = rois[:, 1:]

        # The products below already allocate fresh tensors, so the input is
        # never written to and no defensive clone is needed.
        xs = rel_roi_points[:, :, 0] * (rois[:, None, 2] - rois[:, None, 0])
        ys = rel_roi_points[:, :, 1] * (rois[:, None, 3] - rois[:, None, 1])
        xs += rois[:, None, 0]
        ys += rois[:, None, 1]
        abs_img_points = torch.stack([xs, ys], dim=2)
    return abs_img_points
|
|
|
|
|
|
def get_shape_from_feature_map(x):
    """Return the (width, height) of a feature map as a float tensor.

    During ONNX export the size is read with ``shape_as_tensor`` (presumably
    so it stays a graph value rather than a constant); otherwise it is read
    from ``x.shape``.

    Args:
        x (torch.Tensor): Input tensor, shape (N, C, H, W)

    Returns:
        torch.Tensor: Spatial resolution (width, height), shape (1, 1, 2),
        float, on ``x.device``.
    """
    if not torch.onnx.is_in_onnx_export():
        spatial_hw = torch.tensor(x.shape[2:])
    else:
        spatial_hw = shape_as_tensor(x)[2:]
    spatial_wh = spatial_hw.flip(0).view(1, 1, 2)
    return spatial_wh.to(x.device).float()
|
|
|
|
|
|
def abs_img_point_to_rel_img_point(abs_img_points, img, spatial_scale=1.):
    """Rescale absolute image coordinates into relative sampling coordinates.

    Points are divided by the image (width, height) and then multiplied by
    ``spatial_scale``.

    Args:
        abs_img_points (Tensor): Absolute point coordinates in the image,
            shape (N, P, 2), ordered (x, y).
        img (tuple/Tensor): Either an ``(height, width)`` tuple or a 4D
            feature map whose spatial size is used.
        spatial_scale (float): Extra factor applied to the points. Default: 1.

    Returns:
        Tensor: Relative image point coordinates for sampling, shape (N, P, 2)
    """
    is_hw_tuple = isinstance(img, tuple) and len(img) == 2
    is_feature_map = isinstance(img, torch.Tensor) and len(img.shape) == 4
    assert is_hw_tuple or is_feature_map

    if not isinstance(img, tuple):
        scale = get_shape_from_feature_map(img)
    else:
        height, width = img
        scale = torch.tensor([width, height],
                             dtype=torch.float,
                             device=abs_img_points.device).view(1, 1, 2)

    return abs_img_points / scale * spatial_scale
|
|
|
|
|
|
def rel_roi_point_to_rel_img_point(rois,
                                   rel_roi_points,
                                   img,
                                   spatial_scale=1.):
    """Convert RoI-relative point coordinates to image-relative coordinates.

    The points are first mapped to absolute image coordinates with
    ``rel_roi_point_to_abs_img_point``, then normalized by the image size
    (and scaled) with ``abs_img_point_to_rel_img_point``.

    Args:
        rois (Tensor): RoIs or BBoxes, shape (N, 4) or (N, 5)
        rel_roi_points (Tensor): Point coordinates inside each RoI, relative
            to the RoI, range (0, 1), shape (N, P, 2)
        img (tuple/Tensor): (height, width) of image or feature map.
        spatial_scale (float): Scale points by this factor. Default: 1.

    Returns:
        Tensor: Image based relative point coordinates for sampling,
        shape (N, P, 2)
    """
    return abs_img_point_to_rel_img_point(
        rel_roi_point_to_abs_img_point(rois, rel_roi_points), img,
        spatial_scale)
|
|
|
|
|
|
def point_sample(input, points, align_corners=False, **kwargs):
    """Sample ``input`` at normalized point locations via grid sampling.

    Unlike :func:`torch.nn.functional.grid_sample`, ``points`` are expected
    in the ``[0, 1] x [0, 1]`` square; they are mapped to ``[-1, 1]`` before
    sampling. A 3D ``points`` tensor is accepted in addition to the 4D grid
    form.

    Args:
        input (Tensor): Feature map, shape (N, C, H, W).
        points (Tensor): Normalized point coordinates, range [0, 1] x [0, 1],
            shape (N, P, 2) or (N, Hgrid, Wgrid, 2).
        align_corners (bool): Whether align_corners. Default: False
        **kwargs: Extra keyword arguments for ``F.grid_sample``. They are not
            forwarded on the ONNX fallback path (``bilinear_grid_sample``).

    Returns:
        Tensor: Features of ``points`` on ``input``, shape (N, C, P) or
        (N, C, Hgrid, Wgrid).
    """
    was_3d = points.dim() == 3
    grid01 = points.unsqueeze(2) if was_3d else points
    use_fallback = is_in_onnx_export_without_custom_ops()
    grid = denormalize(grid01)
    if use_fallback:
        sampled = bilinear_grid_sample(
            input, grid, align_corners=align_corners)
    else:
        sampled = F.grid_sample(
            input, grid, align_corners=align_corners, **kwargs)
    return sampled.squeeze(3) if was_3d else sampled
|
|
|
|
|
|
class SimpleRoIAlign(nn.Module):

    def __init__(self, output_size, spatial_scale, aligned=True):
        """Simple RoI align in PointRend, faster than standard RoIAlign.

        Args:
            output_size (tuple[int]): h, w
            spatial_scale (float): scale the input boxes by this number
            aligned (bool): if False, use the legacy implementation in
                MMDetection, align_corners=True will be used in F.grid_sample.
                If True, align the results more perfectly.
        """
        super().__init__()
        self.output_size = _pair(output_size)
        self.spatial_scale = float(spatial_scale)
        # Not read anywhere in this class; kept so external code that inspects
        # the attribute keeps working.
        self.use_torchvision = False
        self.aligned = aligned

    def forward(self, features, rois):
        """Pool a fixed-size grid of point features for every RoI.

        Args:
            features (Tensor): Feature map, shape (N, C, H, W).
            rois (Tensor): RoIs, shape (num_rois, 5); column 0 is the batch
                index of the image each RoI belongs to.

        Returns:
            Tensor: Pooled features, shape (num_rois, C, out_h, out_w); row
            ``i`` corresponds to ``rois[i]``. Outside ONNX export, RoIs whose
            batch index is not in ``[0, N)`` receive all-zero features.
        """
        num_imgs = features.size(0)
        num_rois = rois.size(0)
        channels = features.size(1)
        rel_roi_points = generate_grid(
            num_rois, self.output_size, device=rois.device)

        if torch.onnx.is_in_onnx_export():
            # NOTE(review): the reshape to (num_imgs, -1, ...) assumes the
            # RoIs are grouped by image with an equal count per image.
            rel_img_points = rel_roi_point_to_rel_img_point(
                rois, rel_roi_points, features, self.spatial_scale)
            rel_img_points = rel_img_points.reshape(num_imgs, -1,
                                                    *rel_img_points.shape[1:])
            point_feats = point_sample(
                features, rel_img_points, align_corners=not self.aligned)
            point_feats = point_feats.transpose(1, 2)
        else:
            num_points = rel_roi_points.size(1)
            # Write each image's results into the rows of its own RoIs. This
            # keeps the output aligned with ``rois`` even if RoIs are not
            # sorted by batch index, and an empty ``rois`` yields an empty
            # result instead of failing in torch.cat([]).
            point_feats = features.new_zeros((num_rois, channels, num_points))
            batch_inds = rois[:, 0].long()
            for batch_ind in range(num_imgs):
                inds = batch_inds == batch_ind
                if not inds.any():
                    continue
                feat = features[batch_ind].unsqueeze(0)
                rel_img_points = rel_roi_point_to_rel_img_point(
                    rois[inds], rel_roi_points[inds], feat,
                    self.spatial_scale).unsqueeze(0)
                # (1, C, K, P) -> (K, C, P) for the K RoIs of this image.
                point_feat = point_sample(
                    feat, rel_img_points, align_corners=not self.aligned)
                point_feats[inds] = point_feat.squeeze(0).transpose(0, 1)

        roi_feats = point_feats.reshape(num_rois, channels, *self.output_size)
        return roi_feats

    def __repr__(self):
        return '{}(output_size={}, spatial_scale={}, aligned={})'.format(
            self.__class__.__name__, self.output_size, self.spatial_scale,
            self.aligned)
|
|
|