|
|
|
from typing import Tuple
|
|
|
|
import torch
|
|
from torch import nn as nn
|
|
from torch.autograd import Function
|
|
|
|
from ..utils import ext_loader
|
|
from .ball_query import ball_query
|
|
from .knn import knn
|
|
|
|
# Load the ``_ext`` extension through ``ext_loader``, requesting the two
# ``group_points`` functions that ``GroupingOperation`` calls in its
# forward and backward passes.
ext_module = ext_loader.load_ext(
    '_ext', ['group_points_forward', 'group_points_backward'])
|
|
|
|
|
|
class QueryAndGroup(nn.Module):
    """Collect a fixed-size neighbourhood of points around every centroid.

    Neighbour indices come from a ball query when ``max_radius`` is set and
    from kNN when it is ``None``. The indexed coordinates (expressed relative
    to their centroid) and, optionally, the indexed features are returned.

    Args:
        max_radius (float): The maximum radius of the balls.
            If None is given, kNN sampling is used instead of ball query.
        sample_num (int): Maximum number of features to gather in the ball.
        min_radius (float, optional): The minimum radius of the balls.
            Default: 0.
        use_xyz (bool, optional): Whether to use xyz.
            Default: True.
        return_grouped_xyz (bool, optional): Whether to return grouped xyz.
            Default: False.
        normalize_xyz (bool, optional): Whether to normalize xyz (divide
            the relative coordinates by ``max_radius``).
            Default: False.
        uniform_sample (bool, optional): Whether to sample uniformly.
            Default: False.
        return_unique_cnt (bool, optional): Whether to return the count of
            unique samples. Requires ``uniform_sample``. Default: False.
        return_grouped_idx (bool, optional): Whether to return grouped idx.
            Default: False.
    """

    def __init__(self,
                 max_radius,
                 sample_num,
                 min_radius=0,
                 use_xyz=True,
                 return_grouped_xyz=False,
                 normalize_xyz=False,
                 uniform_sample=False,
                 return_unique_cnt=False,
                 return_grouped_idx=False):
        super().__init__()
        # Neighbourhood definition.
        self.max_radius = max_radius
        self.min_radius = min_radius
        self.sample_num = sample_num
        # Feature composition.
        self.use_xyz = use_xyz
        self.normalize_xyz = normalize_xyz
        self.uniform_sample = uniform_sample
        # Optional extra outputs.
        self.return_grouped_xyz = return_grouped_xyz
        self.return_unique_cnt = return_unique_cnt
        self.return_grouped_idx = return_grouped_idx

        # The unique count is only computed on the uniform-sampling path.
        assert not self.return_unique_cnt or self.uniform_sample, \
            'uniform_sample should be True when ' \
            'returning the count of unique samples'
        # Normalisation divides by max_radius, which kNN mode does not have.
        assert self.max_radius is not None or not self.normalize_xyz, \
            'can not normalize grouped xyz when max_radius is None'

    def _resample_duplicates(self, idx):
        """Rewrite every index row as its unique entries plus random repeats.

        ``idx`` is modified in place. Each row ends up holding its distinct
        indices followed by ``sample_num - num_unique`` entries drawn at
        random (with replacement) from those distinct indices.

        Args:
            idx (Tensor): (B, npoint, sample_num) neighbour indices.

        Returns:
            Tensor: (B, npoint) number of distinct indices in each row.
        """
        num_batches, num_regions = idx.shape[0], idx.shape[1]
        unique_cnt = torch.zeros((num_batches, num_regions))
        for b in range(num_batches):
            for r in range(num_regions):
                distinct = torch.unique(idx[b, r, :])
                n_distinct = distinct.shape[0]
                unique_cnt[b, r] = n_distinct
                picks = torch.randint(
                    0,
                    n_distinct, (self.sample_num - n_distinct, ),
                    dtype=torch.long)
                idx[b, r, :] = torch.cat((distinct, distinct[picks]))
        return unique_cnt

    def forward(self, points_xyz, center_xyz, features=None):
        """Group points (and features) around the given centroids.

        Args:
            points_xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            center_xyz (Tensor): (B, npoint, 3) coordinates of the centroids.
            features (Tensor): (B, C, N) Descriptors of the features.

        Returns:
            Tensor | tuple: (B, 3 + C, npoint, sample_num) grouped feature.
            When any of ``return_grouped_xyz``, ``return_unique_cnt`` or
            ``return_grouped_idx`` is set, a tuple is returned holding the
            grouped feature followed by the requested extras in that order.
        """
        # Pick the neighbour indices: ball query if a radius is configured,
        # kNN otherwise.
        if self.max_radius is not None:
            idx = ball_query(self.min_radius, self.max_radius,
                             self.sample_num, points_xyz, center_xyz)
        else:
            idx = knn(self.sample_num, points_xyz, center_xyz, False)
            # Swap the last two axes of the kNN result before grouping.
            idx = idx.transpose(1, 2).contiguous()

        if self.uniform_sample:
            unique_cnt = self._resample_duplicates(idx)

        # The grouping op takes channel-first input: (B, 3, N).
        grouped_xyz = grouping_operation(
            points_xyz.transpose(1, 2).contiguous(), idx)
        # Express every neighbour relative to its centroid.
        centers = center_xyz.transpose(1, 2).unsqueeze(-1)
        grouped_xyz_diff = grouped_xyz - centers
        if self.normalize_xyz:
            grouped_xyz_diff /= self.max_radius

        if features is None:
            assert (self.use_xyz
                    ), 'Cannot have not features and not use xyz as a feature!'
            new_features = grouped_xyz_diff
        else:
            grouped_features = grouping_operation(features, idx)
            if self.use_xyz:
                # Relative coordinates come first on the channel axis.
                new_features = torch.cat(
                    [grouped_xyz_diff, grouped_features], dim=1)
            else:
                new_features = grouped_features

        outputs = [new_features]
        if self.return_grouped_xyz:
            outputs.append(grouped_xyz)
        if self.return_unique_cnt:
            outputs.append(unique_cnt)
        if self.return_grouped_idx:
            outputs.append(idx)

        # A lone result is returned bare rather than as a 1-tuple.
        return outputs[0] if len(outputs) == 1 else tuple(outputs)
|
|
|
|
|
|
class GroupAll(nn.Module):
    """Treat the whole point set as one single group.

    Args:
        use_xyz (bool): Whether to prepend the xyz coordinates to the
            features on the channel axis.
    """

    def __init__(self, use_xyz: bool = True):
        super().__init__()
        self.use_xyz = use_xyz

    def forward(self,
                xyz: torch.Tensor,
                new_xyz: torch.Tensor,
                features: torch.Tensor = None):
        """Stack coordinates and features into one group.

        Args:
            xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            new_xyz (Tensor): new xyz coordinates of the features. Not read
                by this method; accepted for signature compatibility.
            features (Tensor): (B, C, N) features to group.

        Returns:
            Tensor: (B, C + 3, 1, N) grouped feature. Without ``features``
            only the (B, 3, 1, N) coordinates are returned; with
            ``use_xyz=False`` only the (B, C, 1, N) features are returned.
        """
        # (B, N, 3) -> (B, 3, N) -> (B, 3, 1, N)
        coords = torch.unsqueeze(torch.transpose(xyz, 1, 2), 2)
        if features is None:
            return coords

        # (B, C, N) -> (B, C, 1, N)
        feats = torch.unsqueeze(features, 2)
        if not self.use_xyz:
            return feats

        # Coordinates occupy the first three channels.
        return torch.cat([coords, feats], dim=1)
|
|
|
|
|
|
class GroupingOperation(Function):
    """Group feature with given index.

    Autograd wrapper around the ``group_points_forward`` /
    ``group_points_backward`` functions of the compiled extension.
    Presumably ``output[b, c, i, j] = features[b, c, indices[b, i, j]]``;
    confirm against the extension source.
    """

    @staticmethod
    def forward(ctx, features: torch.Tensor,
                indices: torch.Tensor) -> torch.Tensor:
        """
        Args:
            features (Tensor): (B, C, N) tensor of features to group.
            indices (Tensor): (B, npoint, nsample) the indices of
                features to group with.

        Returns:
            Tensor: (B, C, npoint, nsample) Grouped features.
        """
        # The extension is handed raw tensors, so make them dense first.
        features = features.contiguous()
        indices = indices.contiguous()

        B, nfeatures, nsample = indices.size()
        _, C, N = features.size()
        # Allocate the result next to the input. The legacy
        # ``torch.cuda.FloatTensor`` constructor always produced a float32
        # buffer on the *current* CUDA device, which is the wrong device
        # whenever ``features`` lives on another GPU. The buffer is left
        # uninitialised, as before, for the extension to fill.
        output = features.new_empty((B, C, nfeatures, nsample))

        ext_module.group_points_forward(B, C, N, nfeatures, nsample, features,
                                        indices, output)

        # Indices carry no gradient; keep them (and N) for the backward pass.
        ctx.for_backwards = (indices, N)
        return output

    @staticmethod
    def backward(ctx,
                 grad_out: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """
        Args:
            grad_out (Tensor): (B, C, npoint, nsample) tensor of the gradients
                of the output from forward.

        Returns:
            tuple: ``(grad_features, None)`` where ``grad_features`` is the
            (B, C, N) gradient of the features; ``None`` is returned for
            ``indices``, which receive no gradient.
        """
        idx, N = ctx.for_backwards

        B, C, npoint, nsample = grad_out.size()
        # Zero-initialised accumulator on the same device / dtype as the
        # incoming gradient (see the allocation note in ``forward``).
        grad_features = grad_out.new_zeros((B, C, N))

        grad_out_data = grad_out.data.contiguous()
        ext_module.group_points_backward(B, C, N, npoint, nsample,
                                         grad_out_data, idx,
                                         grad_features)
        return grad_features, None
|
|
|
|
|
|
# Functional entry point: call ``grouping_operation(features, indices)``
# instead of using the ``GroupingOperation`` class directly.
grouping_operation = GroupingOperation.apply
|
|
|