|
import io |
|
from PIL import Image as PImage |
|
import numpy as np |
|
from collections import defaultdict |
|
import cv2 |
|
from typing import Tuple, List |
|
from scipy.spatial.distance import cdist |
|
|
|
from hoho.read_write_colmap import read_cameras_binary, read_images_binary, read_points3D_binary |
|
from hoho.color_mappings import gestalt_color_mapping, ade20k_color_mapping |
|
|
|
from geom_solver import GeomSolver, my_empty_solution, cheat_the_metric_solution |
|
|
|
|
|
def convert_entry_to_human_readable(entry):
    """Decode the raw fields of a dataset entry into ready-to-use objects.

    Fields are handled by key:

    * keys listed in ``passthrough_keys`` are copied unchanged;
    * ``'points3d'``, ``'cameras'`` and ``'images'`` are byte blobs parsed
      with the matching reader imported from ``hoho.read_write_colmap``;
    * ``'ade20k'`` and ``'gestalt'`` are sequences of encoded images, each
      opened with PIL and converted to RGB;
    * ``'depthcm'`` is a sequence of encoded images opened with PIL and
      left in their native mode.

    Any other key is dropped from the result.

    Args:
        entry: Mapping from field name to raw value.

    Returns:
        A new dict holding the decoded fields, in the iteration order of
        ``entry``.
    """
    passthrough_keys = (
        '__key__', 'wf_vertices', 'wf_edges', 'edge_semantics',
        'mesh_vertices', 'mesh_faces', 'face_semantics', 'K', 'R', 't',
    )
    rgb_image_keys = ('ade20k', 'gestalt')

    decoded = {}
    for key, raw in entry.items():
        if key in passthrough_keys:
            decoded[key] = raw
        elif key == 'points3d':
            decoded[key] = read_points3D_binary(fid=io.BytesIO(raw))
        elif key == 'cameras':
            decoded[key] = read_cameras_binary(fid=io.BytesIO(raw))
        elif key == 'images':
            decoded[key] = read_images_binary(fid=io.BytesIO(raw))
        elif key in rgb_image_keys:
            decoded[key] = [
                PImage.open(io.BytesIO(blob)).convert('RGB') for blob in raw
            ]
        elif key == 'depthcm':
            # Depth maps keep whatever mode the file was stored in.
            decoded[key] = [PImage.open(io.BytesIO(blob)) for blob in raw]
    return decoded
|
|
|
|
|
def predict(entry, visualize=False) -> Tuple[object, np.ndarray, List]:
    """Predict a wireframe for one dataset entry.

    Runs ``GeomSolver().solve(entry)`` and sanitises the result so that a
    ``(vertices, edges)`` pair is always returned, substituting fallback
    solutions when the solver fails or produces unusable output.

    Args:
        entry: Dataset sample. Must provide ``'__key__'``; when
            ``visualize`` is true it must also provide ``'wf_vertices'``
            and ``'wf_edges'`` (the ground truth to plot against).
        visualize: If true, plot the estimate next to the ground truth
            using ``hoho.viz3d.plot_estimate_and_gt``.

    Returns:
        The 3-tuple ``(entry['__key__'], vertices, edges)``.
    """
    # Default answer used whenever the solver output cannot be trusted.
    vertices0, edges0 = my_empty_solution()

    try:
        vertices, edges = GeomSolver().solve(entry)
    except Exception as exc:
        # Catch Exception rather than using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit still propagate; report the cause
        # instead of discarding it.
        print('ERROR', repr(exc))
        vertices, edges = cheat_the_metric_solution()

    # Enough vertices but no edges: keep the vertices, borrow default edges.
    if (len(edges) < 1) and (len(vertices) >= 2):
        edges = edges0

    # Still degenerate (too few vertices, or the default edges were empty
    # as well): use the default solution wholesale.
    if (len(vertices) < 2) or (len(edges) < 1):
        vertices, edges = vertices0, edges0

    if visualize:
        # Imported lazily so the plotting stack is only needed on demand.
        from hoho.viz3d import plot_estimate_and_gt
        plot_estimate_and_gt(vertices,
                             edges,
                             entry['wf_vertices'],
                             entry['wf_edges'])

    # The last axis of the vertices must have length 3. ``np.shape`` also
    # accepts plain sequences, and the ``[-1:]`` slice yields ``()`` for
    # 0-d input, so malformed vertices trigger the fallback rather than an
    # AttributeError / IndexError.
    if np.shape(vertices)[-1:] != (3,):
        print("Wrong size")
        vertices, edges = vertices0, edges0

    return entry['__key__'], vertices, edges
|
|