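"""Geometric wireframe solver (DRAK) for the hoho house-reconstruction task.

Sketch of the pipeline implemented below: COLMAP SfM points are projected into
'gestalt' segmentation views; points that land near the vertex classes
('apex', 'eave_end_point') are clustered with an adaptive k-means into
wireframe vertices, and candidate edges are scored against distance
transforms of the edge-class masks ('eave', 'rake', 'valley', 'ridge').
"""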
import itertools

import cv2
import hoho
import numpy as np
import torch
from hoho.color_mappings import gestalt_color_mapping
from pytorch3d.ops import ball_query
from pytorch3d.renderer import PerspectiveCameras

from handcrafted_solution import convert_entry_to_human_readable
from helpers import *  # provides to_K and quaternion_to_rotation_matrix


def my_empty_solution():
    """Fallback prediction: 18 vertices at the origin and a single dummy edge."""
    return np.zeros((18, 3)), [(0, 0)]


class GeomSolver(object):

    def __init__(self):
        self.min_vertices = 18     # pad the predicted vertex set up to this size
        self.kmeans_th = 50        # stop adding clusters once two centers get this close (scene units)
        self.point_dist_th = 25    # max mean reprojection distance (px) to a vertex mask
        self.clr_th = 2.5          # +/- color tolerance when thresholding gestalt segmentations
        self.device = 'cuda:0'
        self.return_edges = False  # edge estimation is off by default
        self.mean_fixed = False    # if True, padding keeps the vertex mean at wf_center

    def cluster_points(self, point_types):
        """Cluster SfM points that reproject near the given gestalt vertex classes."""
        point_colors = []
        for point_type in point_types:
            point_colors.append(np.array(gestalt_color_mapping[point_type]))

        dist_points = np.zeros((self.verts.shape[0], ))
        visible_counts = np.zeros((self.verts.shape[0], ), dtype=int)
        proj_uv = []
        for ki in range(len(self.gestalt_to_colmap_cams)):
            if self.broken_cams[ki]:
                proj_uv.append(([], []))
                continue
            cki = self.gestalt_to_colmap_cams[ki]
            gest = self.gests[ki]

            # Union of the color masks for all requested vertex classes.
            vert_mask = 0
            for point_color in point_colors:
                my_mask = cv2.inRange(gest, point_color - self.clr_th, point_color + self.clr_th)
                vert_mask = vert_mask + my_mask
            vert_mask = (vert_mask > 0).astype(np.uint8)
            # Per-pixel distance (px) to the nearest mask pixel.
            dist = cv2.distanceTransform(1 - vert_mask, cv2.DIST_L2, 3)

            # Accumulate that distance for every SfM point observed in this image.
            in_this_image = np.array([cki in p.image_ids for p in self.points3D.values()])
            uv = torch.round(self.pyt_cameras[ki].transform_points(self.verts)[:, :2]).cpu().numpy().astype(int)
            uv_inl = (uv[:, 0] >= 0) * (uv[:, 1] >= 0) * (uv[:, 0] < self.width) * (uv[:, 1] < self.height) * in_this_image
            proj_uv.append((uv, uv_inl))
            uv = uv[uv_inl]
            dist_points[uv_inl] += dist[uv[:, 1], uv[:, 0]]
            visible_counts[uv_inl] += 1

        # Keep points whose mean distance to the vertex masks is small enough.
        selected_points = (dist_points / (visible_counts + 1e-6)) <= self.point_dist_th
        selected_points[visible_counts < 1] = False

        # Keep only dense points: two neighbors within a radius-30 ball
        # (K=3 because the query point itself is returned at distance 0).
        pnts = torch.from_numpy(self.xyz[selected_points].astype(np.float32))[None]
        bdists, inds, nn = ball_query(pnts, pnts, K=3, radius=30)
        dense_pnts = (bdists[0] > 0).sum(1) == 2

        criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 200, 0.3)
        flags = cv2.KMEANS_RANDOM_CENTERS
        centers = np.zeros((0, 3))
        assigned_points = []
        if len(self.xyz[selected_points][dense_pnts]) == 0:
            return centers, assigned_points

        # Increase k until two cluster centers come closer than kmeans_th,
        # then keep the previous (coarser) clustering.
        for tempi in range(1, 20):
            retval, temp_bestLabels, temp_centers = cv2.kmeans(
                self.xyz[selected_points][dense_pnts].astype(np.float32),
                tempi, None, criteria, 200, flags)
            cpnts = torch.from_numpy(temp_centers.astype(np.float32))[None]
            bdists, inds, nn = ball_query(cpnts, cpnts, K=1, radius=100)
            if bdists.max() > 0:
                closest_nn = (bdists[bdists > 0].min() ** 0.5).item()  # ball_query returns squared distances
            else:
                closest_nn = self.kmeans_th
            if closest_nn < self.kmeans_th:
                break
            centers, bestLabels = temp_centers, temp_bestLabels
        if centers.shape[0] == 0:
            # Safety net: keep at least the last attempted clustering.
            centers, bestLabels = temp_centers, temp_bestLabels

        # Record which original point indices belong to each cluster.
        point_inds = np.arange(self.xyz.shape[0])
        for ci in range(centers.shape[0]):
            assigned_inds = point_inds[selected_points][dense_pnts][bestLabels[:, 0] == ci]
            assigned_points.append(assigned_inds)
        return centers, assigned_points

    def process_vertices(self):
        """Estimate wireframe vertices by clustering apex / eave-end SfM points."""
        human_entry = self.human_entry

        # Eye/target representation of every COLMAP camera
        # (the intrinsics of COLMAP camera 1 are used for all views).
        col_cams = [hoho.Rt_to_eye_target(human_entry['ade20k'][0], to_K(*human_entry['cameras'][1].params),
                                          quaternion_to_rotation_matrix(colmap_img.qvec), colmap_img.tvec)
                    for colmap_img in human_entry['images'].values()]
        eye, target, up, fov = col_cams[0]
        cameras, images, self.points3D = human_entry['cameras'], human_entry['images'], human_entry['points3d']
        colmap_cameras_tf = list(human_entry['images'].keys())
        self.xyz = np.stack([p.xyz for p in self.points3D.values()])
        color = np.stack([p.rgb for p in self.points3D.values()])
        self.gests = [np.array(gest0) for gest0 in human_entry['gestalt']]

        # Match every gestalt view to the nearest COLMAP camera by camera-center
        # distance; views without a COLMAP camera within 300 units are flagged as broken.
        gestalt_camcet = np.stack([eye for eye, target, up, fov in itertools.starmap(
            hoho.Rt_to_eye_target, zip(*[human_entry[k] for k in 'ade20k K R t'.split()]))])
        col_camcet = np.stack([eye for eye, target, up, fov in col_cams])
        self.gestalt_to_colmap_cams = [colmap_cameras_tf[np.argmin(((gcam - col_camcet) ** 2).sum(1) ** 0.5)]
                                       for gcam in gestalt_camcet]
        self.broken_cams = np.array([np.min(((gcam - col_camcet) ** 2).sum(1) ** 0.5) for gcam in gestalt_camcet]) > 300

        # Build PyTorch3D cameras for the matched COLMAP poses.
        self.height, self.width = cameras[1].height, cameras[1].width
        N = len(self.gestalt_to_colmap_cams)
        K = to_K(*human_entry['cameras'][1].params)[None].repeat(N, 0)
        R = np.stack([quaternion_to_rotation_matrix(human_entry['images'][self.gestalt_to_colmap_cams[ind]].qvec) for ind in range(N)])
        T = np.stack([human_entry['images'][self.gestalt_to_colmap_cams[ind]].tvec for ind in range(N)])
        R = np.linalg.inv(R)  # PyTorch3D expects the row-vector (transposed) rotation convention
        image_size = torch.Tensor([self.height, self.width]).repeat(N, 1)
        self.pyt_cameras = PerspectiveCameras(device=self.device, R=R, T=T, in_ndc=False,
                                              focal_length=K[:, 0, :1], principal_point=K[:, :2, 2],
                                              image_size=image_size)
        self.verts = torch.from_numpy(self.xyz.astype(np.float32)).to(self.device)

        # Cluster apex and eave-end vertices separately, then concatenate;
        # is_apex marks which centers came from the apex pass.
        centers_apex, assigned_apex = self.cluster_points(['apex'])
        centers_eave, assigned_eave = self.cluster_points(['eave_end_point'])
        centers = np.concatenate((centers_apex, centers_eave))
        self.assigned_points = assigned_apex + assigned_eave
        self.is_apex = np.zeros((centers.shape[0], )).astype(int)
        self.is_apex[:centers_apex.shape[0]] = 1

        # Rough wireframe center: mean of all points above the lowest cluster
        # (minus a 50-unit margin), with z replaced by the mean cluster height.
        z_th = centers[:, -1].min() - 50
        self.wf_center = self.xyz[self.xyz[:, -1] > z_th].mean(0)
        self.wf_center[-1] = centers[:, -1].mean()
        self.vertices = centers

        # Pad the vertex set up to min_vertices with copies of a center point.
        if self.broken_cams.any():
            vertices = centers
            print("There are broken cams.")
        else:
            nvert = centers.shape[0]
            if nvert >= self.min_vertices:
                # Guard: nothing to pad (the original code assumed nvert < min_vertices,
                # which divides by zero below when nvert == min_vertices).
                vertices = centers
            else:
                if self.mean_fixed:
                    # Pick the padding point so the padded set keeps wf_center as its mean.
                    added_one = (self.min_vertices * self.wf_center - self.vertices.sum(0)) / (self.min_vertices - nvert)
                else:
                    added_one = self.wf_center
                added = added_one[None].repeat(self.min_vertices - nvert, 0)
                vertices = np.concatenate((self.vertices, added))
        self.vertices_aug = vertices

    def process_edges(self):
        """Score vertex pairs against the gestalt edge masks and return the kept edges."""
        N = len(self.gests)

        # For each vertex cluster, the set of COLMAP image ids observing any of its points.
        point_ids = np.array([p.id for p in self.points3D.values()])
        center_visibility = [set(np.concatenate([self.points3D[point_ids[pind]].image_ids for pind in ass_item]))
                             for ass_item in self.assigned_points]
        pyt_centers = torch.from_numpy(self.vertices.astype(np.float32)).to(self.device)

        # The edge class follows from the endpoint types:
        # 0 = eave-eave, 1 = apex-eave (rake/valley), 2 = apex-apex (ridge).
        edge_types = {0: ['eave'], 1: ['rake', 'valley'], 2: ['ridge']}

        # Per view and per edge type: distance transform of the dilated edge mask.
        edge_dists = []
        uvs = []
        for ki in range(N):
            gest = self.gests[ki]
            edge_masks = {}
            per_type_dists = {}
            for etype in edge_types:
                edge_mask = 0
                for edge_class in edge_types[etype]:
                    edge_color = np.array(gestalt_color_mapping[edge_class])
                    mask = cv2.morphologyEx(cv2.inRange(gest,
                                                        edge_color - self.clr_th,
                                                        edge_color + self.clr_th),
                                            cv2.MORPH_DILATE, np.ones((3, 3), np.uint8))
                    edge_mask += mask
                edge_mask = (edge_mask > 0).astype(np.uint8)
                edge_masks[etype] = edge_mask
                dist = cv2.distanceTransform(1 - edge_mask, cv2.DIST_L2, 3)
                per_type_dists[etype] = dist
            edge_dists.append(per_type_dists)
            uv = torch.round(self.pyt_cameras[ki].transform_points(pyt_centers)[:, :2]).cpu().numpy().astype(int)
            uv_inl = (uv[:, 0] >= 0) * (uv[:, 1] >= 0) * (uv[:, 0] < self.width) * (uv[:, 1] < self.height)
            uv = uv[uv_inl]
            uvs.append(uv)

        # For every vertex pair, sample 20 points along the 3D segment, project them
        # into each view that sees at least one endpoint, and measure the mean distance
        # to the matching edge mask. Thresholds are [best view, mean over views] in px.
        edges = []
        thresholds_min_mean = {0: [5, 7], 1: [9, 25], 2: [30, 1000]}
        for i in range(pyt_centers.shape[0]):
            for j in range(i + 1, pyt_centers.shape[0]):
                etype = (self.is_apex[i] + self.is_apex[j])
                points_inter = pyt_centers[i][None] + torch.linspace(0, 1, 20)[:, None].to(self.device) * (pyt_centers[j][None] - pyt_centers[i][None])
                min_mean_dist = 1000
                all_dists = []
                best_ki = -1
                best_uvi = -1
                for ki in range(N):
                    cki = self.gestalt_to_colmap_cams[ki]
                    if not ((cki in center_visibility[i]) or (cki in center_visibility[j])):
                        continue
                    if self.broken_cams[ki]:
                        continue
                    uvi = torch.round(self.pyt_cameras[ki].transform_points(points_inter)[:, :2]).cpu().numpy().astype(int)
                    # Skip views where the projected segment leaves the image.
                    if (uvi <= 0).any() or (uvi[:, 0] >= self.width).any() or (uvi[:, 1] >= self.height).any():
                        continue
                    mean_dist = edge_dists[ki][etype][uvi[:, 1], uvi[:, 0]].mean()
                    all_dists.append(mean_dist)
                    if mean_dist < min_mean_dist:
                        min_mean_dist = mean_dist
                        best_ki = ki
                        best_uvi = uvi
                if best_ki == -1:
                    continue
                ths = thresholds_min_mean[etype]
                if min_mean_dist < ths[0] and np.mean(all_dists) < ths[1]:
                    edges.append((i, j))
        if len(edges) == 0:
            # Always return at least one (degenerate) edge, matching my_empty_solution.
            edges.append((0, 0))
        return edges

    def solve(self, entry, visualize=False):
        human_entry = convert_entry_to_human_readable(entry)
        self.human_entry = human_entry
        self.process_vertices()
        vertices = self.vertices_aug
        if self.return_edges:
            edges = self.process_edges()
        else:
            edges = [(0, 0)]
        if visualize:
            from hoho.viz3d import plot_estimate_and_gt
            plot_estimate_and_gt(vertices, [(0, 0)], self.human_entry['wf_vertices'], self.human_entry['wf_edges'])
        return vertices, edges
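

# Minimal usage sketch (not part of the original module). It assumes `entry` is one
# raw record of the dataset, in the format consumed by
# convert_entry_to_human_readable; loading the dataset is omitted here.
if __name__ == "__main__":
    solver = GeomSolver()
    solver.return_edges = True  # also score edges instead of returning the dummy (0, 0)
    # entry = ...               # hypothetical: fetch one dataset record
    # vertices, edges = solver.solve(entry, visualize=False)
    # print(vertices.shape, edges)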