Files
marf/ifield/data/common/scan.py
2025-01-09 15:43:11 +01:00

769 lines
30 KiB
Python

from ...utils.helpers import compose
from . import points
from .h5_dataclasses import H5Dataclass, H5Array, H5ArrayNoSlice, TransformableDataclassMixin
from methodtools import lru_cache
from sklearn.neighbors import BallTree
import faiss
from trimesh import Trimesh
from typing import Iterable
from typing import Optional, TypeVar
import mesh_to_sdf
import mesh_to_sdf.scan as sdf_scan
import numpy as np
import trimesh
import trimesh.transformations as T
import warnings
__doc__ = """
Here are some helper types for data.
"""

# Generic type variable used by fluent methods that return their own class.
# Fixed: the TypeVar's declared name must match the variable it is bound to.
_T = TypeVar("_T")
class InvalidateLRUOnWriteMixin:
    """
    Mixin invalidating all methodtools LRU caches on the instance whenever a
    regular attribute is written, keeping cached derived values consistent.
    """
    def __setattr__(self, key, value):
        # Writes to the cache-wire attributes themselves must not trigger
        # invalidation, otherwise caching could never take hold.
        if not key.startswith("__wire|"):
            wires = [name for name in dir(self) if name.startswith("__wire|")]
            for name in wires:
                getattr(self, name).cache_clear()
        return super().__setattr__(key, value)
def lru_property(func):
    # Cached read-only property. Relies on methodtools' lru_cache (imported at
    # the top of this module), which — unlike functools' — can wrap a property
    # descriptor and caches per instance via "__wire|..." attributes; those are
    # the caches that InvalidateLRUOnWriteMixin clears on attribute writes.
    return lru_cache(maxsize=1)(property(func))
class SingleViewScan(H5Dataclass, TransformableDataclassMixin, InvalidateLRUOnWriteMixin, require_all=True):
    """
    A single-view scan of a shape: the surface points hit by camera rays, the
    far-plane points of rays that missed, and optional normals, colors,
    per-miss ray distances, UV (image-grid) masks and camera matrices.
    """
    points_hit     : H5ArrayNoSlice           # (N, 3)
    normals_hit    : Optional[H5ArrayNoSlice] # (N, 3)
    points_miss    : H5ArrayNoSlice           # (M, 3)
    distances_miss : Optional[H5ArrayNoSlice] # (M)
    colors_hit     : Optional[H5ArrayNoSlice] # (N, 3)
    colors_miss    : Optional[H5ArrayNoSlice] # (M, 3)
    uv_hits        : Optional[H5ArrayNoSlice] # (H, W) dtype=bool
    uv_miss        : Optional[H5ArrayNoSlice] # (H, W) dtype=bool (both are stored due to missing depth sensor data or filtered backfaces)
    cam_pos        : H5ArrayNoSlice           # (3)
    cam_mat4       : Optional[H5ArrayNoSlice] # (4, 4)
    proj_mat4      : Optional[H5ArrayNoSlice] # (4, 4)
    transforms     : dict[str, H5ArrayNoSlice] # a map of 4x4 transformation matrices

    def transform(self: _T, mat4: np.ndarray, inplace=False) -> _T:
        """
        Apply a homogeneous 4x4 transform to all spatial data.
        Only uniformly scaled transforms are supported (asserted).
        """
        # Per-axis scale is the norm of each linear column, not its sum —
        # the column sum is only correct for axis-aligned scale matrices.
        # https://math.stackexchange.com/a/1463487
        scale_xyz = np.linalg.norm(mat4[:3, :3], axis=0)
        assert np.all(np.abs(scale_xyz - scale_xyz[0]) < 1e-8), f"differently scaled axes: {scale_xyz}"
        out = self if inplace else self.copy(deep=False)
        out.points_hit = T.transform_points(self.points_hit, mat4)
        if self.normals_hit is not None:
            # Normals are directions: apply rotation/scale only (no translation),
            # then re-normalize to stay unit-length under uniform scaling.
            normals = T.transform_points(self.normals_hit, mat4, translate=False)
            out.normals_hit = normals / np.linalg.norm(normals, axis=-1)[:, None]
        else:
            out.normals_hit = None
        out.points_miss = T.transform_points(self.points_miss, mat4)
        # Ray distances scale with the (uniform) spatial scale; may be absent.
        out.distances_miss = (self.distances_miss * scale_xyz[0]) if self.distances_miss is not None else None
        out.cam_pos = T.transform_points(self.points_cam, mat4)[0]
        out.cam_mat4 = (mat4 @ self.cam_mat4) if self.cam_mat4 is not None else None
        out.proj_mat4 = (mat4 @ self.proj_mat4) if self.proj_mat4 is not None else None
        return out

    def compute_miss_distances(self: _T, *, copy: bool = False, deep: bool = False) -> _T:
        """
        Fill `distances_miss` with each missing ray's distance to the hit
        point cloud. Mutates in place unless `copy=True`; returns the scan.
        """
        assert not self.has_miss_distances
        if not self.is_hitting:
            raise ValueError("No hits to compute the ray distance towards")
        out = self.copy(deep=deep) if copy else self
        out.distances_miss = distance_from_rays_to_point_cloud(
            ray_origins = out.points_cam,
            ray_dirs    = out.ray_dirs_miss,
            points      = out.points_hit,
        ).astype(out.points_cam.dtype)
        return out

    @lru_property
    def points(self) -> np.ndarray: # (N+M+1, 3)
        """All points: hits, then misses, then the camera position."""
        return np.concatenate((
            self.points_hit,
            self.points_miss,
            self.points_cam,
        ))

    @lru_property
    def uv_points(self) -> np.ndarray: # (H, W, 3)
        """Points scattered back onto the UV grid; NaN where neither hit nor miss."""
        if not self.has_uv: raise ValueError
        out = np.full((*self.uv_hits.shape, 3), np.nan, dtype=self.points_hit.dtype)
        out[self.uv_hits, :] = self.points_hit
        out[self.uv_miss, :] = self.points_miss
        return out

    @lru_property
    def uv_normals(self) -> np.ndarray: # (H, W, 3)
        """Hit normals scattered back onto the UV grid; NaN elsewhere."""
        if not self.has_uv: raise ValueError
        out = np.full((*self.uv_hits.shape, 3), np.nan, dtype=self.normals_hit.dtype)
        out[self.uv_hits, :] = self.normals_hit
        return out

    @lru_property
    def points_cam(self) -> Optional[np.ndarray]: # (1, 3)
        """Camera position with a leading singleton axis, or None."""
        if self.cam_pos is None: return None
        return self.cam_pos[None, :]

    @lru_property
    def points_hit_centroid(self) -> np.ndarray:
        """Mean of the hit points."""
        return self.points_hit.mean(axis=0)

    @lru_property
    def points_hit_std(self) -> np.ndarray:
        """Per-axis standard deviation of the hit points."""
        return self.points_hit.std(axis=0)

    @lru_property
    def is_hitting(self) -> bool:
        return len(self.points_hit) > 0

    @lru_property
    def is_empty(self) -> bool:
        return not (len(self.points_hit) or len(self.points_miss))

    @lru_property
    def has_colors(self) -> bool:
        return self.colors_hit is not None or self.colors_miss is not None

    @lru_property
    def has_normals(self) -> bool:
        return self.normals_hit is not None

    @lru_property
    def has_uv(self) -> bool:
        return self.uv_hits is not None

    @lru_property
    def has_miss_distances(self) -> bool:
        return self.distances_miss is not None

    @lru_property
    def xyzrgb_hit(self) -> np.ndarray: # (N, 6)
        """Hit points concatenated with their colors."""
        if self.colors_hit is None: raise ValueError
        return np.concatenate([self.points_hit, self.colors_hit], axis=1)

    @lru_property
    def xyzrgb_miss(self) -> np.ndarray: # (M, 6)
        """Miss points concatenated with their colors."""
        if self.colors_miss is None: raise ValueError
        return np.concatenate([self.points_miss, self.colors_miss], axis=1)

    @lru_property
    def ray_dirs_hit(self) -> np.ndarray: # (N, 3)
        """Unit direction from the camera to each hit point."""
        out = self.points_hit - self.points_cam
        out /= np.linalg.norm(out, axis=-1)[:, None] # normalize
        return out

    @lru_property
    def ray_dirs_miss(self) -> np.ndarray: # (M, 3)
        """Unit direction from the camera to each miss (far-plane) point."""
        out = self.points_miss - self.points_cam
        out /= np.linalg.norm(out, axis=-1)[:, None] # normalize
        return out

    @classmethod
    def from_mesh_single_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewScan":
        """Scan `mesh` from a single viewpoint (random unless phi/theta given)."""
        if "phi" not in kw and "theta" not in kw:
            kw["theta"], kw["phi"] = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
        scan = sample_single_view_scan_from_mesh(mesh, **kw)
        if compute_miss_distances and scan.is_hitting:
            scan.compute_miss_distances()
        return scan

    def to_uv_scan(self) -> "SingleViewUVScan":
        """Convert to the UV-grid representation."""
        return SingleViewUVScan.from_scan(self)

    @classmethod
    def from_uv_scan(cls, uvscan: "SingleViewUVScan") -> "SingleViewScan":
        """Inverse of `to_uv_scan`. (Fixed: `cls` and the return annotation.)"""
        return uvscan.to_scan()
# The same, but with support for pagination (should have been this way since the start...)
class SingleViewUVScan(H5Dataclass, TransformableDataclassMixin, InvalidateLRUOnWriteMixin, require_all=True):
    """
    Scan stored on a (possibly flattened) UV grid, with pagination support:
    every per-ray array shares the same leading batch shape *B.
    """
    # B may be (N) or (H, W), the latter may be flattened
    hits       : H5Array                  # (*B) dtype=bool
    miss       : H5Array                  # (*B) dtype=bool (both are stored due to missing depth sensor data or filtered backface hits)
    points     : H5Array                  # (*B, 3) on far plane if miss, NaN if neither hit or miss
    normals    : Optional[H5Array]        # (*B, 3) NaN if not hit
    colors     : Optional[H5Array]        # (*B, 3)
    distances  : Optional[H5Array]        # (*B) NaN if not miss
    cam_pos    : Optional[H5ArrayNoSlice] # (3) or (*B, 3)
    cam_mat4   : Optional[H5ArrayNoSlice] # (4, 4)
    proj_mat4  : Optional[H5ArrayNoSlice] # (4, 4)
    transforms : dict[str, H5ArrayNoSlice] # a map of 4x4 transformation matrices

    @classmethod
    def from_scan(cls, scan: SingleViewScan):
        """Scatter a SingleViewScan onto its UV grid (requires UV masks)."""
        if not scan.has_uv:
            raise ValueError("Scan cloud has no UV data")
        hits, miss = scan.uv_hits, scan.uv_miss
        dtype = scan.points_hit.dtype
        assert hits.ndim in (1, 2), hits.ndim
        assert hits.shape == miss.shape, (hits.shape, miss.shape)
        points = np.full((*hits.shape, 3), np.nan, dtype=dtype)
        points[hits, :] = scan.points_hit
        points[miss, :] = scan.points_miss
        normals = None
        if scan.has_normals:
            normals = np.full((*hits.shape, 3), np.nan, dtype=dtype)
            normals[hits, :] = scan.normals_hit
        distances = None
        if scan.has_miss_distances:
            distances = np.full(hits.shape, np.nan, dtype=dtype)
            distances[miss] = scan.distances_miss
        colors = None
        if scan.has_colors:
            colors = np.full((*hits.shape, 3), np.nan, dtype=dtype)
            if scan.colors_hit is not None:
                colors[hits, :] = scan.colors_hit
            if scan.colors_miss is not None:
                colors[miss, :] = scan.colors_miss
        return cls(
            hits       = hits,
            miss       = miss,
            points     = points,
            normals    = normals,
            colors     = colors,
            distances  = distances,
            cam_pos    = scan.cam_pos,
            cam_mat4   = scan.cam_mat4,
            proj_mat4  = scan.proj_mat4,
            transforms = scan.transforms,
        )

    def to_scan(self) -> "SingleViewScan":
        """Gather the UV grid back into a compact SingleViewScan."""
        if not self.is_single_view: raise ValueError
        return SingleViewScan(
            points_hit     = self.points   [self.hits, :],
            points_miss    = self.points   [self.miss, :],
            normals_hit    = self.normals  [self.hits, :] if self.has_normals else None,
            distances_miss = self.distances[self.miss]    if self.has_miss_distances else None,
            colors_hit     = self.colors   [self.hits, :] if self.has_colors else None,
            colors_miss    = self.colors   [self.miss, :] if self.has_colors else None,
            uv_hits        = self.hits,
            uv_miss        = self.miss,
            cam_pos        = self.cam_pos,
            cam_mat4       = self.cam_mat4,
            proj_mat4      = self.proj_mat4,
            transforms     = self.transforms,
        )

    def to_mesh(self) -> trimesh.Trimesh:
        """
        Triangulate the 2D hit grid: every 2x2 cell with exactly 3 hits becomes
        one triangle, every fully-hit cell becomes two.
        """
        faces: list[tuple] = []
        for x in range(self.hits.shape[0]-1):
            for y in range(self.hits.shape[1]-1):
                c11 = x, y
                c12 = x, y+1
                c22 = x+1, y+1
                c21 = x+1, y
                n = sum(map(self.hits.__getitem__, (c11, c12, c22, c21)))
                if n == 3:
                    faces.append((*filter(self.hits.__getitem__, (c11, c12, c22, c21)),))
                elif n == 4:
                    faces.append((c11, c12, c22))
                    faces.append((c11, c22, c21))
        xy2idx = {c: i for i, c in enumerate(set(k for j in faces for k in j))}
        # Fixed: colors were asserted non-None here even though the line below
        # already treats them as optional; colorless grids now triangulate too.
        return trimesh.Trimesh(
            vertices      = [self.points[i] for i in xy2idx.keys()],
            vertex_colors = [self.colors[i] for i in xy2idx.keys()] if self.colors is not None else None,
            faces         = [tuple(xy2idx[i] for i in face) for face in faces],
        )

    def transform(self: _T, mat4: np.ndarray, inplace=False) -> _T:
        """
        Apply a homogeneous 4x4 transform to all spatial data.
        Only uniformly scaled transforms are supported (asserted).
        """
        # Per-axis scale is the norm of each linear column, not its sum.
        # https://math.stackexchange.com/a/1463487
        scale_xyz = np.linalg.norm(mat4[:3, :3], axis=0)
        assert np.all(np.abs(scale_xyz - scale_xyz[0]) < 1e-8), f"differently scaled axes: {scale_xyz}"
        unflat = self.hits.shape
        flat = int(np.prod(unflat))  # np.product was removed in numpy 2.0
        out = self if inplace else self.copy(deep=False)
        # Fixed: reshape previously tried to unpack an int, and the method
        # referenced nonexistent attributes (normals_hit / distances_miss).
        out.points = T.transform_points(self.points.reshape((flat, 3)), mat4).reshape((*unflat, 3))
        if self.normals is not None:
            # Normals are directions: rotation/scale only, then re-normalize.
            normals = T.transform_points(self.normals.reshape((flat, 3)), mat4, translate=False)
            normals /= np.linalg.norm(normals, axis=-1)[:, None]
            out.normals = normals.reshape((*unflat, 3))
        else:
            out.normals = None
        out.distances = (self.distances * scale_xyz[0]) if self.distances is not None else None
        if self.cam_pos is not None:
            # cam_pos may be a single (3,) position or batched (*B, 3)
            out.cam_pos = T.transform_points(self.cam_pos.reshape((-1, 3)), mat4).reshape(self.cam_pos.shape)
        else:
            out.cam_pos = None
        out.cam_mat4 = (mat4 @ self.cam_mat4) if self.cam_mat4 is not None else None
        out.proj_mat4 = (mat4 @ self.proj_mat4) if self.proj_mat4 is not None else None
        return out

    def compute_miss_distances(self: _T, *, copy: bool = False, deep: bool = False, surface_points: Optional[np.ndarray] = None) -> _T:
        """
        Fill `distances` with each missing ray's distance to the hit cloud
        (or to `surface_points` if given).
        NOTE(review): the field comment says non-miss entries should be NaN,
        but they are zero-filled here; kept as-is to preserve behavior.
        """
        assert not self.has_miss_distances
        shape = self.hits.shape
        out = self.copy(deep=deep) if copy else self
        out.distances = np.zeros(shape, dtype=self.points.dtype)
        if self.is_hitting:
            out.distances[self.miss] = distance_from_rays_to_point_cloud(
                ray_origins = self.cam_pos_unsqueezed_miss,
                ray_dirs    = self.ray_dirs_miss,
                points      = surface_points if surface_points is not None else self.points[self.hits],
            )
        return out

    def fill_missing_points(self: _T, *, copy: bool = False, deep: bool = False) -> _T:
        """
        Fill in missing points as hitting the far plane, by unprojecting their
        UV coordinates through the inverse projection and camera matrices.
        """
        if not self.is_2d:
            raise ValueError("Cannot fill missing points for non-2d scan!")
        if not self.is_single_view:
            raise ValueError("Cannot fill missing points for non-single-view scans!")
        if self.cam_mat4 is None:
            raise ValueError("cam_mat4 is None")
        if self.proj_mat4 is None:
            raise ValueError("proj_mat4 is None")
        uv = np.argwhere(self.missing).astype(self.points.dtype)
        # NOTE(review): rows are normalized by width and columns by height here,
        # which only matches the obvious mapping for square grids — verify.
        uv[:, 0] /= (self.missing.shape[1] - 1) / 2
        uv[:, 1] /= (self.missing.shape[0] - 1) / 2
        uv -= 1
        uv = np.stack((
            uv[:, 1],
            -uv[:, 0],
            np.ones(uv.shape[0]), # far clipping plane
            np.ones(uv.shape[0]), # homogeneous coordinate
        ), axis=-1)
        uv = uv @ (self.cam_mat4 @ np.linalg.inv(self.proj_mat4)).T
        out = self.copy(deep=deep) if copy else self
        out.points[self.missing, :] = uv[:, :3] / uv[:, 3][:, None]
        return out

    @lru_property
    def is_hitting(self) -> bool:
        return np.any(self.hits)

    @lru_property
    def has_colors(self) -> bool:
        return self.colors is not None

    @lru_property
    def has_normals(self) -> bool:
        return self.normals is not None

    @lru_property
    def has_miss_distances(self) -> bool:
        return self.distances is not None

    @lru_property
    def any_missing(self) -> bool:
        return np.any(self.missing)

    @lru_property
    def has_missing(self) -> bool:
        """True when there are missing rays and their points have been filled."""
        return self.any_missing and not np.any(np.isnan(self.points[self.missing]))

    @lru_property
    def cam_pos_unsqueezed(self) -> H5Array:
        """Camera position broadcastable against (*B, 3) point arrays."""
        if self.cam_pos.ndim != 1:
            return self.cam_pos
        else:
            cam_pos = self.cam_pos
            for _ in range(self.hits.ndim):
                cam_pos = cam_pos[None, ...]
            return cam_pos

    @lru_property
    def cam_pos_unsqueezed_hit(self) -> H5Array:
        """Camera position(s) aligned with the hit subset."""
        if self.cam_pos.ndim != 1:
            return self.cam_pos[self.hits, :]
        else:
            return self.cam_pos[None, :]

    @lru_property
    def cam_pos_unsqueezed_miss(self) -> H5Array:
        """Camera position(s) aligned with the miss subset."""
        if self.cam_pos.ndim != 1:
            return self.cam_pos[self.miss, :]
        else:
            return self.cam_pos[None, :]

    @lru_property
    def ray_dirs(self) -> H5Array:
        """Unit ray directions for every grid cell."""
        return (self.points - self.cam_pos_unsqueezed) * (1 / self.depths[..., None])

    @lru_property
    def ray_dirs_hit(self) -> H5Array:
        """Unit ray directions for the hit subset."""
        out = self.points[self.hits, :] - self.cam_pos_unsqueezed_hit
        out /= np.linalg.norm(out, axis=-1)[..., None] # normalize
        return out

    @lru_property
    def ray_dirs_miss(self) -> H5Array:
        """Unit ray directions for the miss subset."""
        out = self.points[self.miss, :] - self.cam_pos_unsqueezed_miss
        out /= np.linalg.norm(out, axis=-1)[..., None] # normalize
        return out

    @lru_property
    def depths(self) -> H5Array:
        """Euclidean distance from the camera to every point."""
        return np.linalg.norm(self.points - self.cam_pos_unsqueezed, axis=-1)

    @lru_property
    def missing(self) -> H5Array:
        """Mask of rays that are neither hit nor miss (no data)."""
        return ~(self.hits | self.miss)

    @classmethod
    def from_mesh_single_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewUVScan":
        """Scan `mesh` from one view (random unless phi/theta given), as a UV grid."""
        if "phi" not in kw and "theta" not in kw:
            kw["theta"], kw["phi"] = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
        scan = sample_single_view_scan_from_mesh(mesh, **kw).to_uv_scan()
        if compute_miss_distances:
            scan.compute_miss_distances()
        assert scan.is_2d
        return scan

    @classmethod
    def from_mesh_sphere_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewUVScan":
        """Scan `mesh` with rays between all pairs of sphere points."""
        scan = sample_sphere_view_scan_from_mesh(mesh, **kw)
        if compute_miss_distances:
            surface_points = None
            if scan.hits.sum() > mesh.vertices.shape[0]:
                # The raw vertices are a cheaper stand-in for the hit cloud.
                surface_points = mesh.vertices.astype(scan.points.dtype)
                # Fixed: only transform surface_points when they exist —
                # previously this ran on None and raised TypeError.
                if not kw.get("no_unit_sphere", False):
                    translation, scale = compute_unit_sphere_transform(mesh, dtype=scan.points.dtype)
                    surface_points = (surface_points + translation) * scale
            scan.compute_miss_distances(surface_points=surface_points)
        assert scan.is_flat
        return scan

    def flatten_and_permute_(self: _T, copy=False) -> _T: # inplace by default
        """Flatten the UV grid and shuffle all per-ray arrays with one shared permutation."""
        n_items = int(np.prod(self.hits.shape))  # np.product was removed in numpy 2.0
        permutation = np.random.permutation(n_items)
        out = self.copy(deep=False) if copy else self
        out.hits      = out.hits     .reshape((n_items, ))[permutation]
        out.miss      = out.miss     .reshape((n_items, ))[permutation]
        out.points    = out.points   .reshape((n_items, 3))[permutation, :]
        out.normals   = out.normals  .reshape((n_items, 3))[permutation, :] if out.has_normals else None
        out.colors    = out.colors   .reshape((n_items, 3))[permutation, :] if out.has_colors else None
        out.distances = out.distances.reshape((n_items, ))[permutation] if out.has_miss_distances else None
        return out

    @property
    def is_single_view(self) -> bool:
        """True when there is (at most) one camera position."""
        return np.prod(self.cam_pos.shape[:-1]) == 1 if self.cam_pos is not None else True

    @property
    def is_flat(self) -> bool:
        """True when the batch shape is flattened to 1D."""
        return len(self.hits.shape) == 1

    @property
    def is_2d(self) -> bool:
        """True when the batch shape is a 2D (H, W) grid."""
        return len(self.hits.shape) == 2
# transforms can be found in pytorch3d.transforms and in open3d
# and in trimesh.transformations
def sample_single_view_scans_from_mesh(
        mesh               : Trimesh,
        *,
        n_batches          : int,
        scan_resolution    : int   = 400,
        compute_normals    : bool  = False,
        fov                : float = 1.0472, # 60 degrees in radians, vertical field of view.
        camera_distance    : float = 2,
        no_filter_backhits : bool  = False,
        ) -> Iterable[SingleViewScan]:
    """
    Lazily yield `n_batches` scans of `mesh` from random spherical viewpoints.

    A shared cache list avoids re-normalizing the mesh to the unit sphere for
    every view. (Fixed: the call previously passed `_mesh_is_normalized`, a
    keyword `sample_single_view_scan_from_mesh` does not accept, so every
    yield raised TypeError.)
    """
    normalized_mesh_cache = []
    for _ in range(n_batches):
        theta, phi = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
        yield sample_single_view_scan_from_mesh(
            mesh               = mesh,
            phi                = phi,
            theta              = theta,
            scan_resolution    = scan_resolution,
            compute_normals    = compute_normals,
            fov                = fov,
            camera_distance    = camera_distance,
            no_filter_backhits = no_filter_backhits,
            _mesh_cache        = normalized_mesh_cache,
        )
def sample_single_view_scan_from_mesh(
        mesh : Trimesh,
        *,
        phi : float,
        theta : float,
        scan_resolution : int = 200,
        compute_normals : bool = False,
        fov : float = 1.0472, # 60 degrees in radians, vertical field of view.
        camera_distance : float = 2,
        no_filter_backhits : bool = False,
        no_unit_sphere : bool = False,
        dtype : type = np.float32,
        _mesh_cache : Optional[list] = None, # provide a list if mesh is reused
        ) -> SingleViewScan:
    """
    Render one depth scan of `mesh` from spherical camera angles (phi, theta).

    The mesh is normalized to the unit sphere before scanning (mesh_to_sdf
    convention); pass `no_unit_sphere=True` to map the results back to model
    space. Backface hits are filtered out unless `no_filter_backhits` (which
    then requires `compute_normals`). Far-plane points of missing rays are
    reconstructed by unprojecting their pixel coordinates.
    """
    # scale and center to unit sphere
    is_cache = isinstance(_mesh_cache, list)
    if is_cache and _mesh_cache and _mesh_cache[0] is mesh:
        # cache hit: list holds [original_mesh, normalized_mesh, translation, scale]
        _, mesh, translation, scale = _mesh_cache
    else:
        if is_cache:
            if _mesh_cache:
                _mesh_cache.clear()
            _mesh_cache.append(mesh)
        translation, scale = compute_unit_sphere_transform(mesh)
        mesh = mesh_to_sdf.scale_to_unit_sphere(mesh)
        if is_cache:
            _mesh_cache.extend((mesh, translation, scale))
    z_near = 1
    z_far = 3
    cam_mat4 = sdf_scan.get_camera_transform_looking_at_origin(phi, theta, camera_distance=camera_distance)
    # camera position = transform applied to the origin (homogeneous)
    cam_pos = cam_mat4 @ np.array([0, 0, 0, 1])
    scan = sdf_scan.Scan(mesh,
        camera_transform = cam_mat4,
        resolution = scan_resolution,
        calculate_normals = compute_normals,
        fov = fov,
        z_near = z_near,
        z_far = z_far,
        no_flip_backfaced_normals = True
    )
    # all the scan rays that hit the far plane, based on sdf_scan.Scan.__init__
    misses = np.argwhere(scan.depth_buffer == 0)
    points_miss = np.ones((misses.shape[0], 4))
    # pixel indices -> NDC in [-1, 1]; note the (row, col) -> (y, x) swap
    points_miss[:, [1, 0]] = misses.astype(float) / (scan_resolution -1) * 2 - 1
    points_miss[:, 1] *= -1 # image rows grow downwards, y grows upwards
    points_miss[:, 2] = 1 # far plane in clipping space
    # unproject: clip space -> camera space -> world space
    points_miss = points_miss @ (cam_mat4 @ np.linalg.inv(scan.projection_matrix)).T
    points_miss /= points_miss[:, 3][:, np.newaxis] # perspective divide
    points_miss = points_miss[:, :3]
    uv_hits = scan.depth_buffer != 0
    uv_miss = ~uv_hits
    if not no_filter_backhits:
        if not compute_normals:
            raise ValueError("not `no_filter_backhits` requires `compute_normals`")
        # inner product: keep points whose normal faces the camera (front faces)
        mask = np.einsum('ij,ij->i', scan.points - cam_pos[:3][None, :], scan.normals) < 0
        scan.points = scan.points [mask, :]
        scan.normals = scan.normals[mask, :]
        uv_hits[uv_hits] = mask
    transforms = {}
    # undo unit-sphere transform
    if no_unit_sphere:
        # inverse of the normalization p' = (p + translation) * scale
        scan.points = scan.points * (1 / scale) - translation
        points_miss = points_miss * (1 / scale) - translation
        cam_pos[:3] = cam_pos[:3] * (1 / scale) - translation
        cam_mat4[:3, :] *= 1 / scale
        cam_mat4[:3, 3] -= translation
        # NOTE(review): trimesh's scale_and_translate builds `s*p + t`, while
        # the unit-sphere mapping is `(p + t) * s` — confirm the intended
        # ordering of the stored matrix.
        transforms["unit_sphere"] = T.scale_and_translate(scale=scale, translate=translation)
        transforms["model"] = np.eye(4)
    else:
        transforms["model"] = np.linalg.inv(T.scale_and_translate(scale=scale, translate=translation))
        transforms["unit_sphere"] = np.eye(4)
    return SingleViewScan(
        normals_hit = scan.normals .astype(dtype),
        points_hit = scan.points .astype(dtype),
        points_miss = points_miss .astype(dtype),
        distances_miss = None,
        colors_hit = None,
        colors_miss = None,
        uv_hits = uv_hits .astype(bool),
        uv_miss = uv_miss .astype(bool),
        cam_pos = cam_pos[:3] .astype(dtype),
        cam_mat4 = cam_mat4 .astype(dtype),
        proj_mat4 = scan.projection_matrix .astype(dtype),
        transforms = {k:v.astype(dtype) for k, v in transforms.items()},
    )
def sample_sphere_view_scan_from_mesh(
        mesh               : Trimesh,
        *,
        sphere_points      : int  = 4000, # resulting rays are n*(n-1)
        compute_normals    : bool = False,
        no_filter_backhits : bool = False,
        no_unit_sphere     : bool = False,
        no_permute         : bool = False,
        dtype              : type = np.float32,
        **kw,
        ) -> SingleViewUVScan:
    """
    Ray-trace `mesh` with rays spanning all ordered pairs of ~`sphere_points`
    equidistant unit-sphere points, returning a flat SingleViewUVScan.

    Backface hits are marked neither hit nor miss unless `no_filter_backhits`
    (which requires `compute_normals`). Output is in unit-sphere space unless
    `no_unit_sphere`, and shuffled unless `no_permute`.
    """
    translation, scale = compute_unit_sphere_transform(mesh, dtype=dtype)
    # get unit-sphere ray endpoints, then transform them into model space
    two_sphere = generate_equidistant_sphere_rays(sphere_points, **kw).astype(dtype) # (n*(n-1), 2, 3)
    two_sphere = two_sphere / scale - translation # we transform after cache lookup
    if mesh.ray.__class__.__module__.split(".")[-1] != "ray_pyembree":
        warnings.warn("Pyembree not found, the ray-tracing will be SLOW!")
    (
        locations,
        index_ray,
        index_tri,
    ) = mesh.ray.intersects_location(
        two_sphere[:, 0, :],
        two_sphere[:, 1, :] - two_sphere[:, 0, :], # direction, not target coordinate
        multiple_hits=False,
    )
    if compute_normals:
        location_normals = mesh.face_normals[index_tri]
    batch = two_sphere.shape[:1]
    # Fixed: np.bool (removed in numpy 1.24) replaced with the builtin bool.
    hits = np.zeros((*batch,), dtype=bool)
    miss = np.ones((*batch,), dtype=bool)
    cam_pos = two_sphere[:, 0, :]
    intersections = two_sphere[:, 1, :] # far-plane, effectively
    normals = np.zeros((*batch, 3), dtype=dtype)
    index_ray_front = index_ray
    if not no_filter_backhits:
        if not compute_normals:
            raise ValueError("not `no_filter_backhits` requires `compute_normals`")
        # keep only hits whose normal faces the ray origin; `intersections`
        # still holds the ray targets at this point, so this is a direction dot
        mask = ((intersections[index_ray] - cam_pos[index_ray]) * location_normals).sum(axis=-1) <= 0
        index_ray_front = index_ray[mask]
    hits[index_ray_front] = True
    miss[index_ray] = False # filtered backface hits end up neither hit nor miss
    intersections[index_ray] = locations
    if compute_normals:
        # Fixed: previously assigned unconditionally, raising NameError when
        # compute_normals=False together with no_filter_backhits=True.
        normals[index_ray] = location_normals
    if not no_permute:
        assert len(batch) == 1, batch
        permutation = np.random.permutation(*batch)
        hits          = hits         [permutation]
        miss          = miss         [permutation]
        intersections = intersections[permutation, :]
        normals       = normals      [permutation, :]
        cam_pos       = cam_pos      [permutation, :]
    # apply unit sphere transform
    if not no_unit_sphere:
        intersections = (intersections + translation) * scale
        cam_pos = (cam_pos + translation) * scale
    return SingleViewUVScan(
        hits       = hits,
        miss       = miss,
        points     = intersections,
        normals    = normals,
        colors     = None, # colors
        distances  = None,
        cam_pos    = cam_pos,
        cam_mat4   = None,
        proj_mat4  = None,
        transforms = {},
    )
def distance_from_rays_to_point_cloud(
        ray_origins     : np.ndarray, # (*A, 3)
        ray_dirs        : np.ndarray, # (*A, 3)
        points          : np.ndarray, # (*B, 3)
        dirs_normalized : bool = False,
        n_steps         : int = 40,
        ) -> np.ndarray: # (A)
    """
    Approximate, for each ray, the perpendicular distance from the ray line to
    the nearest point of `points`.

    Sphere-traces each ray against a nearest-neighbour index: at every step
    the ray advances by a fraction of its current nearest-neighbour distance
    while tracking the globally closest point seen; the result is the
    point-to-line distance to that closest point.
    """
    # anything outside of this volume will never contribute to the result
    max_norm = max(
        np.linalg.norm(ray_origins, axis=-1).max(),
        np.linalg.norm(points, axis=-1).max(),
    ) * 1.02
    if not dirs_normalized:
        ray_dirs = ray_dirs / np.linalg.norm(ray_dirs, axis=-1)[..., None]
    # deal with single-view clouds: one shared origin broadcast to every ray
    if ray_origins.shape != ray_dirs.shape:
        ray_origins = np.broadcast_to(ray_origins, ray_dirs.shape)
    n_points = np.prod(points.shape[:-1])  # np.product was removed in numpy 2.0
    use_faiss = n_points > 160000*4 # exact BallTree below this size, ANN above
    if not use_faiss:
        index = BallTree(points)
    else:
        # http://ann-benchmarks.com/index.html
        assert np.issubdtype(points.dtype, np.float32)
        assert np.issubdtype(ray_origins.dtype, np.float32)
        assert np.issubdtype(ray_dirs.dtype, np.float32)
        index = faiss.index_factory(points.shape[-1], "NSG32,Flat") # https://github.com/facebookresearch/faiss/wiki/The-index-factory
        index.nprobe = 5 # 10 # default is 1
        index.train(points)
        index.add(points)
    if not use_faiss:
        min_d, min_n = index.query(ray_origins, k=1, return_distance=True)
    else:
        min_d, min_n = index.search(ray_origins, k=1)
        min_d = np.sqrt(min_d) # faiss reports squared L2 distances
    acc_d = min_d.copy() # accumulated distance marched along each ray
    for step in range(1, n_steps+1):
        query_points = ray_origins + acc_d * ray_dirs
        if max_norm is not None: # always true today; kept for optional clamping
            qmask = np.linalg.norm(query_points, axis=-1) < max_norm
            if not qmask.any(): break
            query_points = query_points[qmask]
        else:
            qmask = slice(None)
        if not use_faiss:
            current_d, current_n = index.query(query_points, k=1, return_distance=True)
        else:
            current_d, current_n = index.search(query_points, k=1)
            current_d = np.sqrt(current_d)
        if max_norm is not None:
            min_d[qmask] = np.minimum(current_d, min_d[qmask])
            new_min_mask = min_d[qmask] == current_d
            qmask2 = qmask.copy()
            qmask2[qmask2] = new_min_mask[..., 0]
            min_n[qmask2] = current_n[new_min_mask[..., 0]]
            acc_d[qmask] += current_d * 0.25 # conservative sphere-tracing step
        else:
            np.minimum(current_d, min_d, out=min_d)
            new_min_mask = min_d == current_d
            min_n[new_min_mask] = current_n[new_min_mask]
            acc_d += current_d * 0.25
    closest_points = points[min_n[:, 0], :] # k=1
    # point-to-line distance via the cross-product formula (dirs are unit)
    distances = np.linalg.norm(np.cross(closest_points - ray_origins, ray_dirs, axis=-1), axis=-1)
    return distances
# helpers
@compose(np.array) # make copy to avoid lru cache mutation
@lru_cache(maxsize=1)
def generate_equidistant_sphere_rays(n : int, **kw) -> np.ndarray: # output (n*(n-1)) rays, n may be off
    """
    Build rays between all ordered pairs of ~n equidistant sphere points.

    Returns an array of shape (N*(N-1), 2, 3): each entry is an
    (origin, target) pair of distinct sphere points. N may differ slightly
    from n depending on the point-generation scheme.
    """
    sphere_points = points.generate_equidistant_sphere_points(n=n, **kw)
    indices = np.indices((len(sphere_points),))[0] # (N)
    # cartesian product of all index pairs
    cprod = np.transpose([np.tile(indices, len(indices)), np.repeat(indices, len(indices))]) # (N**2, 2)
    # drop degenerate pairs where origin == target
    permutations = cprod[cprod[:, 0] != cprod[:, 1], :] # (N*(N-1), 2)
    # lookup sphere points
    two_sphere = sphere_points[permutations, :] # (N*(N-1), 2, 3)
    return two_sphere
def compute_unit_sphere_transform(mesh: "Trimesh", *, dtype=None) -> tuple[np.ndarray, float]:
    """
    Return the (translation, scale) which mesh_to_sdf applies to meshes before
    computing their SDF cloud, i.e. unit_point = (point + translation) * scale.

    dtype: optional numpy dtype to cast the results to. (Fixed: the previous
    default was the *builtin* `type`, a typo for None, which made the
    unconditional cast path nonsensical.)
    """
    # the transformation applied by mesh_to_sdf.scale_to_unit_sphere(mesh)
    translation = -mesh.bounding_box.centroid
    scale = 1 / np.max(np.linalg.norm(mesh.vertices + translation, axis=1))
    if dtype is not None:
        translation = translation.astype(dtype)
        scale = scale.astype(dtype)
    return translation, scale