from ...utils.helpers import compose
from . import points
from .h5_dataclasses import H5Dataclass, H5Array, H5ArrayNoSlice, TransformableDataclassMixin
from methodtools import lru_cache
from sklearn.neighbors import BallTree
import faiss
from trimesh import Trimesh
from typing import Iterable
from typing import Optional, TypeVar
import mesh_to_sdf
import mesh_to_sdf.scan as sdf_scan
import numpy as np
import trimesh
import trimesh.transformations as T
import warnings

__doc__ = """
Here are some helper types for data.
"""

_T = TypeVar("_T")


class InvalidateLRUOnWriteMixin:
    """
    Mixin that clears every per-instance method cache when any other attribute is assigned.

    Combined with ``lru_property`` this keeps derived values (centroids, ray directions, ...)
    in sync with the raw arrays they are computed from.
    """

    def __setattr__(self, key, value):
        # Attributes named "__wire|..." are treated as the per-instance cache objects
        # (presumably created by methodtools -- confirm); assigning anything else clears them all.
        if not key.startswith("__wire|"):
            for attr in dir(self):
                if attr.startswith("__wire|"):
                    getattr(self, attr).cache_clear()
        return super().__setattr__(key, value)


def lru_property(func):
    """
    Read-only property whose value is cached (a single entry).

    Use together with ``InvalidateLRUOnWriteMixin`` so the cache is dropped on attribute writes.
    """
    return lru_cache(maxsize=1)(property(func))


class SingleViewScan(H5Dataclass, TransformableDataclassMixin, InvalidateLRUOnWriteMixin, require_all=True):
    """
    A single-camera scan stored as separate hit / miss point lists.

    Hits are rays that struck the surface, misses are rays that reached the far plane.
    Optionally the (H, W) pixel masks of both are kept so the scan can be turned back
    into an image-shaped ``SingleViewUVScan``.
    """
    points_hit     : H5ArrayNoSlice            # (N, 3)
    normals_hit    : Optional[H5ArrayNoSlice]  # (N, 3)
    points_miss    : H5ArrayNoSlice            # (M, 3)
    distances_miss : Optional[H5ArrayNoSlice]  # (M)
    colors_hit     : Optional[H5ArrayNoSlice]  # (N, 3)
    colors_miss    : Optional[H5ArrayNoSlice]  # (M, 3)
    uv_hits        : Optional[H5ArrayNoSlice]  # (H, W) dtype=bool
    uv_miss        : Optional[H5ArrayNoSlice]  # (H, W) dtype=bool (the reason we store both is due to missing data depth sensor data or filtered backfaces)
    cam_pos        : H5ArrayNoSlice            # (3)
    cam_mat4       : Optional[H5ArrayNoSlice]  # (4, 4)
    proj_mat4      : Optional[H5ArrayNoSlice]  # (4, 4)
    transforms     : dict[str, H5ArrayNoSlice] # a map of 4x4 transformation matrices

    def transform(self: _T, mat4: np.ndarray, inplace=False) -> _T:
        """
        Apply a similarity transform (rotation/reflection, uniform scale, translation).

        Args:
            mat4: (4, 4) homogeneous matrix whose linear part scales all axes equally.
            inplace: modify ``self`` instead of a shallow copy.

        Returns:
            The transformed scan. Array dtypes are preserved.

        Raises:
            ValueError: if ``mat4`` scales the axes differently.
        """
        mat4 = np.asarray(mat4)
        linear = mat4[:3, :3]
        # the per-axis scale is the norm of each column of the linear part,
        # https://math.stackexchange.com/a/1463487
        scale_xyz = np.linalg.norm(linear, axis=0)
        if not np.allclose(scale_xyz, scale_xyz[0], rtol=1e-5, atol=1e-8):
            raise ValueError(f"differenty scaled axes: {scale_xyz}")
        scale = scale_xyz[0]
        rotation = linear / scale # directions (normals) are only rotated, never translated or scaled

        out = self if inplace else self.copy(deep=False)
        out.points_hit = T.transform_points(self.points_hit, mat4).astype(self.points_hit.dtype, copy=False)
        out.normals_hit = (
            (self.normals_hit @ rotation.T).astype(self.normals_hit.dtype, copy=False)
            if self.normals_hit is not None else None
        )
        out.points_miss = T.transform_points(self.points_miss, mat4).astype(self.points_miss.dtype, copy=False)
        out.distances_miss = (
            (self.distances_miss * scale).astype(self.distances_miss.dtype, copy=False)
            if self.distances_miss is not None else None
        )
        out.cam_pos = T.transform_points(self.points_cam, mat4)[-1].astype(self.cam_pos.dtype, copy=False)
        out.cam_mat4 = (mat4 @ self.cam_mat4) if self.cam_mat4 is not None else None
        # the projection maps camera space to clip space; unprojection elsewhere uses
        # cam_mat4 @ inv(proj_mat4), so a world-space transform must only touch cam_mat4
        out.proj_mat4 = self.proj_mat4
        return out

    def compute_miss_distances(self: _T, *, copy: bool = False, deep: bool = False) -> _T:
        """
        Fill ``distances_miss`` with the distance from each miss ray to the hit point cloud.

        Raises:
            ValueError: if the scan has no hits to measure against.
        """
        assert not self.has_miss_distances
        if not self.is_hitting:
            raise ValueError("No hits to compute the ray distance towards")
        out = self.copy(deep=deep) if copy else self
        out.distances_miss \
            = distance_from_rays_to_point_cloud(
                ray_origins = out.points_cam,
                ray_dirs    = out.ray_dirs_miss,
                points      = out.points_hit,
            ).astype(out.points_cam.dtype)
        return out

    @lru_property
    def points(self) -> np.ndarray: # (N+M+1, 3)
        """Hits, then misses, then the camera position, stacked into one array."""
        return np.concatenate((
            self.points_hit,
            self.points_miss,
            self.points_cam,
        ))

    @lru_property
    def uv_points(self) -> np.ndarray: # (H, W, 3)
        """Points scattered back into the pixel grid; NaN where neither hit nor miss."""
        if not self.has_uv:
            raise ValueError
        out = np.full((*self.uv_hits.shape, 3), np.nan, dtype=self.points_hit.dtype)
        out[self.uv_hits, :] = self.points_hit
        out[self.uv_miss, :] = self.points_miss
        return out

    @lru_property
    def uv_normals(self) -> np.ndarray: # (H, W, 3)
        """Hit normals scattered back into the pixel grid; NaN where there is no hit."""
        if not self.has_uv or not self.has_normals:
            raise ValueError
        out = np.full((*self.uv_hits.shape, 3), np.nan, dtype=self.normals_hit.dtype)
        out[self.uv_hits, :] = self.normals_hit
        return out

    @lru_property
    def points_cam(self) -> Optional[np.ndarray]: # (1, 3)
        """The camera position as a one-row point array, or None without a camera."""
        if self.cam_pos is None:
            return None
        return self.cam_pos[None, :]

    @lru_property
    def points_hit_centroid(self) -> np.ndarray:
        """Mean of the hit points per axis."""
        return self.points_hit.mean(axis=0)

    @lru_property
    def points_hit_std(self) -> np.ndarray:
        """Standard deviation of the hit points per axis."""
        return self.points_hit.std(axis=0)

    @lru_property
    def is_hitting(self) -> bool:
        """True if at least one ray hit the surface."""
        return len(self.points_hit) > 0

    @lru_property
    def is_empty(self) -> bool:
        """True if there are neither hits nor misses."""
        return not (len(self.points_hit) or len(self.points_miss))

    @lru_property
    def has_colors(self) -> bool:
        return self.colors_hit is not None or self.colors_miss is not None

    @lru_property
    def has_normals(self) -> bool:
        return self.normals_hit is not None

    @lru_property
    def has_uv(self) -> bool:
        return self.uv_hits is not None

    @lru_property
    def has_miss_distances(self) -> bool:
        return self.distances_miss is not None

    @lru_property
    def xyzrgb_hit(self) -> np.ndarray: # (N, 6)
        """Hit points with their colors appended column-wise."""
        if self.colors_hit is None:
            raise ValueError
        return np.concatenate([self.points_hit, self.colors_hit], axis=1)

    @lru_property
    def xyzrgb_miss(self) -> np.ndarray: # (M, 6)
        """Miss points with their colors appended column-wise."""
        if self.colors_miss is None:
            raise ValueError
        return np.concatenate([self.points_miss, self.colors_miss], axis=1)

    @lru_property
    def ray_dirs_hit(self) -> np.ndarray: # (N, 3)
        """Unit directions from the camera to each hit point."""
        out = self.points_hit - self.points_cam
        out /= np.linalg.norm(out, axis=-1)[:, None] # normalize
        return out

    @lru_property
    def ray_dirs_miss(self) -> np.ndarray: # (M, 3)
        """Unit directions from the camera to each miss point."""
        out = self.points_miss - self.points_cam
        out /= np.linalg.norm(out, axis=-1)[:, None] # normalize
        return out

    @classmethod
    def from_mesh_single_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewScan":
        """
        Scan ``mesh`` from one camera; a random viewpoint is drawn unless phi/theta are given.

        Extra keyword arguments are forwarded to ``sample_single_view_scan_from_mesh``.
        """
        if "phi" not in kw and "theta" not in kw:
            kw["theta"], kw["phi"] = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
        scan = sample_single_view_scan_from_mesh(mesh, **kw)
        if compute_miss_distances and scan.is_hitting:
            scan.compute_miss_distances()
        return scan

    def to_uv_scan(self) -> "SingleViewUVScan":
        return SingleViewUVScan.from_scan(self)

    @classmethod
    def from_uv_scan(cls, uvscan: "SingleViewUVScan") -> "SingleViewScan":
        return uvscan.to_scan()

# The same, but with support for pagination (should have been this way since the start...)
class SingleViewUVScan(H5Dataclass, TransformableDataclassMixin, InvalidateLRUOnWriteMixin, require_all=True):
    """
    A scan stored as per-ray arrays sharing a batch shape B.

    B may be (N) or (H, W); an (H, W) image may be flattened. Every ray is a hit, a miss,
    or neither ("missing", e.g. filtered backface hits).
    """
    # B may be (N) or (H, W), the latter may be flattened
    hits      : H5Array                    # (*B) dtype=bool
    miss      : H5Array                    # (*B) dtype=bool (the reason we store both is due to missing data depth sensor data or filtered backface hits)
    points    : H5Array                    # (*B, 3) on far plane if miss, NaN if neither hit or miss
    normals   : Optional[H5Array]          # (*B, 3) NaN if not hit (from_scan); zero for non-hits from sphere views
    colors    : Optional[H5Array]          # (*B, 3)
    distances : Optional[H5Array]          # (*B) NaN if not miss (from_scan); compute_miss_distances writes 0 there
    cam_pos   : Optional[H5ArrayNoSlice]   # (3) or (*B, 3)
    cam_mat4  : Optional[H5ArrayNoSlice]   # (4, 4)
    proj_mat4 : Optional[H5ArrayNoSlice]   # (4, 4)
    transforms: dict[str, H5ArrayNoSlice]  # a map of 4x4 transformation matrices

    @classmethod
    def from_scan(cls, scan: SingleViewScan):
        """Scatter a ``SingleViewScan`` with UV masks into image-shaped arrays."""
        if not scan.has_uv:
            raise ValueError("Scan cloud has no UV data")
        hits, miss = scan.uv_hits, scan.uv_miss
        dtype = scan.points_hit.dtype
        assert hits.ndim in (1, 2), hits.ndim
        assert hits.shape == miss.shape, (hits.shape, miss.shape)

        points = np.full((*hits.shape, 3), np.nan, dtype=dtype)
        points[hits, :] = scan.points_hit
        points[miss, :] = scan.points_miss

        normals = None
        if scan.has_normals:
            normals = np.full((*hits.shape, 3), np.nan, dtype=dtype)
            normals[hits, :] = scan.normals_hit

        distances = None
        if scan.has_miss_distances:
            distances = np.full(hits.shape, np.nan, dtype=dtype)
            distances[miss] = scan.distances_miss

        colors = None
        if scan.has_colors:
            colors = np.full((*hits.shape, 3), np.nan, dtype=dtype)
            if scan.colors_hit is not None:
                colors[hits, :] = scan.colors_hit
            if scan.colors_miss is not None:
                colors[miss, :] = scan.colors_miss

        return cls(
            hits       = hits,
            miss       = miss,
            points     = points,
            normals    = normals,
            colors     = colors,
            distances  = distances,
            cam_pos    = scan.cam_pos,
            cam_mat4   = scan.cam_mat4,
            proj_mat4  = scan.proj_mat4,
            transforms = scan.transforms,
        )

    def to_scan(self) -> "SingleViewScan":
        """Gather hits and misses into a ``SingleViewScan`` (single camera only)."""
        if not self.is_single_view:
            raise ValueError
        return SingleViewScan(
            points_hit     = self.points   [self.hits, :],
            points_miss    = self.points   [self.miss, :],
            normals_hit    = self.normals  [self.hits, :] if self.has_normals        else None,
            distances_miss = self.distances[self.miss]    if self.has_miss_distances else None,
            colors_hit     = self.colors   [self.hits, :] if self.has_colors         else None,
            colors_miss    = self.colors   [self.miss, :] if self.has_colors         else None,
            uv_hits        = self.hits,
            uv_miss        = self.miss,
            cam_pos        = self.cam_pos,
            cam_mat4       = self.cam_mat4,
            proj_mat4      = self.proj_mat4,
            transforms     = self.transforms,
        )

    def to_mesh(self) -> trimesh.Trimesh:
        """
        Triangulate the hit pixels of a 2D (H, W) scan.

        A 2x2 pixel quad with four hits yields two triangles, one with exactly three hits
        yields one. Vertex colors are attached only when the scan has colors.

        Raises:
            ValueError: if the scan is not 2D.
        """
        if not self.is_2d:
            raise ValueError("Cannot build a mesh from a non-2d scan!")
        hits = self.hits
        faces: list[tuple[tuple[int, int], ...]] = []
        for x in range(hits.shape[0] - 1):
            for y in range(hits.shape[1] - 1):
                quad = ((x, y), (x, y+1), (x+1, y+1), (x+1, y))
                corners = [c for c in quad if hits[c]]
                if len(corners) == 3:
                    faces.append(tuple(corners))
                elif len(corners) == 4:
                    c11, c12, c22, c21 = quad
                    faces.append((c11, c12, c22))
                    faces.append((c11, c22, c21))

        # pixel coordinate -> vertex index, in first-seen order
        xy2idx = {c: i for i, c in enumerate(dict.fromkeys(k for face in faces for k in face))}
        return trimesh.Trimesh(
            vertices      = [self.points[c] for c in xy2idx],
            vertex_colors = [self.colors[c] for c in xy2idx] if self.colors is not None else None,
            faces         = [tuple(xy2idx[c] for c in face) for face in faces],
        )

    def transform(self: _T, mat4: np.ndarray, inplace=False) -> _T:
        """
        Apply a similarity transform (rotation/reflection, uniform scale, translation).

        Works for any batch shape and for both a single and a per-ray ``cam_pos``.
        Array dtypes are preserved.

        Raises:
            ValueError: if ``mat4`` scales the axes differently.
        """
        mat4 = np.asarray(mat4)
        linear = mat4[:3, :3]
        # the per-axis scale is the norm of each column of the linear part,
        # https://math.stackexchange.com/a/1463487
        scale_xyz = np.linalg.norm(linear, axis=0)
        if not np.allclose(scale_xyz, scale_xyz[0], rtol=1e-5, atol=1e-8):
            raise ValueError(f"differenty scaled axes: {scale_xyz}")
        scale = scale_xyz[0]
        rotation = linear / scale # directions are only rotated

        out = self if inplace else self.copy(deep=False)
        out.points = (
            T.transform_points(self.points.reshape((-1, 3)), mat4)
            .reshape(self.points.shape)
            .astype(self.points.dtype, copy=False)
        )
        out.normals = (
            (self.normals @ rotation.T).astype(self.normals.dtype, copy=False)
            if self.normals is not None else None
        )
        out.distances = (
            (self.distances * scale).astype(self.distances.dtype, copy=False)
            if self.distances is not None else None
        )
        out.cam_pos = (
            T.transform_points(self.cam_pos.reshape((-1, 3)), mat4)
            .reshape(self.cam_pos.shape)
            .astype(self.cam_pos.dtype, copy=False)
            if self.cam_pos is not None else None
        )
        out.cam_mat4 = (mat4 @ self.cam_mat4) if self.cam_mat4 is not None else None
        # unprojection uses cam_mat4 @ inv(proj_mat4): the projection itself is unaffected
        out.proj_mat4 = self.proj_mat4
        return out

    def compute_miss_distances(self: _T, *, copy: bool = False, deep: bool = False, surface_points: Optional[np.ndarray] = None) -> _T:
        """
        Fill ``distances`` with, for every miss ray, its distance to the surface point cloud.

        Non-miss entries are 0. ``surface_points`` defaults to the scan's own hit points.
        """
        assert not self.has_miss_distances
        distances = np.zeros(self.hits.shape, dtype=self.points.dtype)
        if self.is_hitting:
            distances[self.miss] \
                = distance_from_rays_to_point_cloud(
                    ray_origins = self.cam_pos_unsqueezed_miss,
                    ray_dirs    = self.ray_dirs_miss,
                    points      = surface_points if surface_points is not None else self.points[self.hits],
                )
        out = self.copy(deep=deep) if copy else self
        out.distances = distances
        return out

    def fill_missing_points(self: _T, *, copy: bool = False, deep: bool = False) -> _T:
        """
        Fill in missing points as hitting the far plane.

        Each missing pixel is unprojected through ``cam_mat4 @ inv(proj_mat4)`` at clip depth 1.
        A new points array is assigned, so a shallow copy never alters the source scan
        and cached derived properties are invalidated.
        """
        if not self.is_2d:
            raise ValueError("Cannot fill missing points for non-2d scan!")
        if not self.is_single_view:
            raise ValueError("Cannot fill missing points for non-single-view scans!")
        if self.cam_mat4 is None:
            raise ValueError("cam_mat4 is None")
        if self.proj_mat4 is None:
            raise ValueError("proj_mat4 is None")

        missing = self.missing
        height, width = missing.shape
        uv = np.argwhere(missing).astype(self.points.dtype) # (K, 2) as (row, col)
        uv[:, 0] /= (height - 1) / 2 # rows span the image height
        uv[:, 1] /= (width  - 1) / 2 # columns span the image width
        uv -= 1
        uv = np.stack((
            uv[:, 1],
            -uv[:, 0],
            np.ones(uv.shape[0]), # far clipping plane
            np.ones(uv.shape[0]), # homogeneous coordinate
        ), axis=-1)
        uv = uv @ (self.cam_mat4 @ np.linalg.inv(self.proj_mat4)).T

        out = self.copy(deep=deep) if copy else self
        filled = np.array(out.points, copy=True)
        filled[missing, :] = uv[:, :3] / uv[:, 3][:, None]
        out.points = filled
        return out

    @lru_property
    def is_hitting(self) -> bool:
        return np.any(self.hits)

    @lru_property
    def has_colors(self) -> bool:
        return self.colors is not None

    @lru_property
    def has_normals(self) -> bool:
        return self.normals is not None

    @lru_property
    def has_miss_distances(self) -> bool:
        return self.distances is not None

    @lru_property
    def any_missing(self) -> bool:
        """True if some ray is neither a hit nor a miss."""
        return np.any(self.missing)

    @lru_property
    def has_missing(self) -> bool:
        """True if there are missing rays and none of their points is NaN."""
        return self.any_missing and not np.any(np.isnan(self.points[self.missing]))

    @lru_property
    def cam_pos_unsqueezed(self) -> H5Array:
        """``cam_pos`` made broadcastable against ``points``."""
        if self.cam_pos.ndim != 1:
            return self.cam_pos
        cam_pos = self.cam_pos
        for _ in range(self.hits.ndim):
            cam_pos = cam_pos[None, ...]
        return cam_pos

    @lru_property
    def cam_pos_unsqueezed_hit(self) -> H5Array:
        if self.cam_pos.ndim != 1:
            return self.cam_pos[self.hits, :]
        return self.cam_pos[None, :]

    @lru_property
    def cam_pos_unsqueezed_miss(self) -> H5Array:
        if self.cam_pos.ndim != 1:
            return self.cam_pos[self.miss, :]
        return self.cam_pos[None, :]

    @lru_property
    def ray_dirs(self) -> H5Array:
        return (self.points - self.cam_pos_unsqueezed) * (1 / self.depths[..., None])

    @lru_property
    def ray_dirs_hit(self) -> H5Array:
        out = self.points[self.hits, :] - self.cam_pos_unsqueezed_hit
        out /= np.linalg.norm(out, axis=-1)[..., None] # normalize
        return out

    @lru_property
    def ray_dirs_miss(self) -> H5Array:
        out = self.points[self.miss, :] - self.cam_pos_unsqueezed_miss
        out /= np.linalg.norm(out, axis=-1)[..., None] # normalize
        return out

    @lru_property
    def depths(self) -> H5Array:
        """Euclidean distance from the camera to every point."""
        return np.linalg.norm(self.points - self.cam_pos_unsqueezed, axis=-1)

    @lru_property
    def missing(self) -> H5Array:
        """Mask of rays that are neither hit nor miss."""
        return ~(self.hits | self.miss)

    @classmethod
    def from_mesh_single_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewUVScan":
        """Scan ``mesh`` from one camera (random viewpoint unless phi/theta given) as a 2D UV scan."""
        if "phi" not in kw and "theta" not in kw:
            kw["theta"], kw["phi"] = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
        scan = sample_single_view_scan_from_mesh(mesh, **kw).to_uv_scan()
        if compute_miss_distances:
            scan.compute_miss_distances()
        assert scan.is_2d
        return scan

    @classmethod
    def from_mesh_sphere_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewUVScan":
        """
        Ray-trace ``mesh`` along rays between equidistant sphere points (a flat scan).

        When there are more hits than mesh vertices, miss distances are measured against
        the vertices instead of the hit points.
        """
        scan = sample_sphere_view_scan_from_mesh(mesh, **kw)
        if compute_miss_distances:
            surface_points = None
            if scan.hits.sum() > mesh.vertices.shape[0]:
                surface_points = mesh.vertices.astype(scan.points.dtype)
                if not kw.get("no_unit_sphere", False):
                    translation, scale = compute_unit_sphere_transform(mesh, dtype=scan.points.dtype)
                    surface_points = (surface_points + translation) * scale
            scan.compute_miss_distances(surface_points=surface_points)
        assert scan.is_flat
        return scan

    def flatten_and_permute_(self: _T, copy=False) -> _T: # inplace by default
        """Flatten the batch shape to (N) and shuffle all per-ray arrays with one permutation."""
        n_items = int(np.prod(self.hits.shape))
        permutation = np.random.permutation(n_items)

        out = self.copy(deep=False) if copy else self
        out.hits      = out.hits     .reshape((n_items, ))[permutation]
        out.miss      = out.miss     .reshape((n_items, ))[permutation]
        out.points    = out.points   .reshape((n_items, 3))[permutation, :]
        out.normals   = out.normals  .reshape((n_items, 3))[permutation, :] if out.has_normals        else None
        out.colors    = out.colors   .reshape((n_items, 3))[permutation, :] if out.has_colors         else None
        out.distances = out.distances.reshape((n_items, ))[permutation]     if out.has_miss_distances else None
        return out

    @property
    def is_single_view(self) -> bool:
        """True when there is no camera or exactly one camera position."""
        if self.cam_pos is None:
            return True
        return np.prod(self.cam_pos.shape[:-1]) == 1

    @property
    def is_flat(self) -> bool:
        return len(self.hits.shape) == 1

    @property
    def is_2d(self) -> bool:
        return len(self.hits.shape) == 2


# transforms can be found in pytorch3d.transforms and in open3d
# and in trimesh.transformations

def sample_single_view_scans_from_mesh(
        mesh               : Trimesh,
        *,
        n_batches          : int,
        scan_resolution    : int   = 400,
        compute_normals    : bool  = False,
        fov                : float = 1.0472, # 60 degrees in radians, vertical field of view.
        camera_distance    : float = 2,
        no_filter_backhits : bool  = False,
        ) -> Iterable[SingleViewScan]:
    """
    Yield ``n_batches`` single-view scans of ``mesh`` from random viewpoints.

    The unit-sphere-normalized mesh is computed once and reused for every scan.
    """
    normalized_mesh_cache = []
    for _ in range(n_batches):
        theta, phi = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
        yield sample_single_view_scan_from_mesh(
            mesh               = mesh,
            phi                = phi,
            theta              = theta,
            scan_resolution    = scan_resolution,
            compute_normals    = compute_normals,
            fov                = fov,
            camera_distance    = camera_distance,
            no_filter_backhits = no_filter_backhits,
            _mesh_cache        = normalized_mesh_cache,
        )


def sample_single_view_scan_from_mesh(
        mesh               : Trimesh,
        *,
        phi                : float,
        theta              : float,
        scan_resolution    : int   = 200,
        compute_normals    : bool  = False,
        fov                : float = 1.0472, # 60 degrees in radians, vertical field of view.
        camera_distance    : float = 2,
        no_filter_backhits : bool  = False,
        no_unit_sphere     : bool  = False,
        dtype              : type  = np.float32,
        _mesh_cache        : Optional[list] = None, # provide a list if mesh is reused
        ) -> SingleViewScan:
    """
    Render one depth scan of ``mesh`` with mesh_to_sdf from spherical angles (phi, theta).

    The mesh is first scaled to the unit sphere. Rays that hit the far plane become misses.
    Backface hits are dropped unless ``no_filter_backhits`` (this needs ``compute_normals``).
    With ``no_unit_sphere`` the result is mapped back into model space.
    """
    # scale and center to unit sphere
    is_cache = isinstance(_mesh_cache, list)
    if is_cache and _mesh_cache and _mesh_cache[0] is mesh:
        _, mesh, translation, scale = _mesh_cache
    else:
        if is_cache:
            if _mesh_cache:
                _mesh_cache.clear()
            _mesh_cache.append(mesh)
        translation, scale = compute_unit_sphere_transform(mesh)
        mesh = mesh_to_sdf.scale_to_unit_sphere(mesh)
        if is_cache:
            _mesh_cache.extend((mesh, translation, scale))

    z_near = 1
    z_far = 3
    cam_mat4 = sdf_scan.get_camera_transform_looking_at_origin(phi, theta, camera_distance=camera_distance)
    cam_pos = cam_mat4 @ np.array([0, 0, 0, 1])

    scan = sdf_scan.Scan(mesh,
        camera_transform          = cam_mat4,
        resolution                = scan_resolution,
        calculate_normals         = compute_normals,
        fov                       = fov,
        z_near                    = z_near,
        z_far                     = z_far,
        no_flip_backfaced_normals = True,
    )

    # all the scan rays that hit the far plane, based on sdf_scan.Scan.__init__
    misses = np.argwhere(scan.depth_buffer == 0)
    points_miss = np.ones((misses.shape[0], 4))
    points_miss[:, [1, 0]] = misses.astype(float) / (scan_resolution - 1) * 2 - 1
    points_miss[:, 1] *= -1
    points_miss[:, 2] = 1 # far plane in clipping space
    points_miss = points_miss @ (cam_mat4 @ np.linalg.inv(scan.projection_matrix)).T
    points_miss /= points_miss[:, 3][:, np.newaxis]
    points_miss = points_miss[:, :3]

    uv_hits = scan.depth_buffer != 0
    uv_miss = ~uv_hits

    if not no_filter_backhits:
        if not compute_normals:
            raise ValueError("not `no_filter_backhits` requires `compute_normals`")
        # keep hits whose normal faces the camera (negative inner product with the view ray)
        mask = np.einsum('ij,ij->i', scan.points - cam_pos[:3][None, :], scan.normals) < 0
        scan.points  = scan.points [mask, :]
        scan.normals = scan.normals[mask, :]
        uv_hits[uv_hits] = mask # dropped backface pixels become neither hit nor miss

    transforms = {}

    # undo unit-sphere transform
    # NOTE(review): trimesh's scale_and_translate appears to compose x*scale + translate, whereas the
    # unit-sphere transform here is (x + translation) * scale -- confirm which convention consumers expect.
    if no_unit_sphere:
        scan.points = scan.points * (1 / scale) - translation
        points_miss = points_miss * (1 / scale) - translation
        cam_pos[:3] = cam_pos[:3] * (1 / scale) - translation
        cam_mat4[:3, :] *= 1 / scale
        cam_mat4[:3, 3] -= translation

        transforms["unit_sphere"] = T.scale_and_translate(scale=scale, translate=translation)
        transforms["model"] = np.eye(4)
    else:
        transforms["model"] = np.linalg.inv(T.scale_and_translate(scale=scale, translate=translation))
        transforms["unit_sphere"] = np.eye(4)

    return SingleViewScan(
        normals_hit    = scan.normals.astype(dtype) if scan.normals is not None else None,
        points_hit     = scan.points            .astype(dtype),
        points_miss    = points_miss            .astype(dtype),
        distances_miss = None,
        colors_hit     = None,
        colors_miss    = None,
        uv_hits        = uv_hits                .astype(bool),
        uv_miss        = uv_miss                .astype(bool),
        cam_pos        = cam_pos[:3]            .astype(dtype),
        cam_mat4       = cam_mat4               .astype(dtype),
        proj_mat4      = scan.projection_matrix .astype(dtype),
        transforms     = {k: v.astype(dtype) for k, v in transforms.items()},
    )


def sample_sphere_view_scan_from_mesh(
        mesh               : Trimesh,
        *,
        sphere_points      : int  = 4000, # resulting rays are n*(n-1)
        compute_normals    : bool = False,
        no_filter_backhits : bool = False,
        no_unit_sphere     : bool = False,
        no_permute         : bool = False,
        dtype              : type = np.float32,
        **kw,
        ) -> SingleViewUVScan:
    """
    Ray-trace ``mesh`` along every ordered pair of equidistant unit-sphere points.

    Each ray starts at one sphere point (its ``cam_pos``) and points at another. Rays with
    a (front-facing) hit are hits, rays without any intersection are misses. Back-face hits
    are neither, unless ``no_filter_backhits``. ``normals`` is None unless ``compute_normals``.
    Extra keyword arguments go to ``generate_equidistant_sphere_rays``.
    """
    if not no_filter_backhits and not compute_normals:
        raise ValueError("not `no_filter_backhits` requires `compute_normals`")

    translation, scale = compute_unit_sphere_transform(mesh, dtype=dtype)

    # get unit-sphere points, then transform to model space
    two_sphere = generate_equidistant_sphere_rays(sphere_points, **kw).astype(dtype) # (n*(n-1), 2, 3)
    two_sphere = two_sphere / scale - translation # we transform after cache lookup

    if mesh.ray.__class__.__module__.split(".")[-1] != "ray_pyembree":
        warnings.warn("Pyembree not found, the ray-tracing will be SLOW!")

    locations, index_ray, index_tri = mesh.ray.intersects_location(
        two_sphere[:, 0, :],
        two_sphere[:, 1, :] - two_sphere[:, 0, :], # direction, not target coordinate
        multiple_hits=False,
    )
    location_normals = mesh.face_normals[index_tri] if compute_normals else None

    n_rays = two_sphere.shape[0]
    hits = np.zeros(n_rays, dtype=bool)
    miss = np.ones(n_rays, dtype=bool)
    cam_pos = two_sphere[:, 0, :]
    intersections = two_sphere[:, 1, :] # far-plane, effectively

    index_ray_front = index_ray
    if not no_filter_backhits:
        # front-facing: the face normal points against the ray direction
        mask = ((intersections[index_ray] - cam_pos[index_ray]) * location_normals).sum(axis=-1) <= 0
        index_ray_front = index_ray[mask]

    hits[index_ray_front] = True
    miss[index_ray] = False
    intersections[index_ray] = locations

    normals = None
    if compute_normals:
        normals = np.zeros((n_rays, 3), dtype=dtype)
        normals[index_ray] = location_normals

    if not no_permute:
        permutation = np.random.permutation(n_rays)
        hits          = hits         [permutation]
        miss          = miss         [permutation]
        intersections = intersections[permutation, :]
        normals       = normals      [permutation, :] if normals is not None else None
        cam_pos       = cam_pos      [permutation, :]

    # apply unit sphere transform
    if not no_unit_sphere:
        intersections = (intersections + translation) * scale
        cam_pos = (cam_pos + translation) * scale

    return SingleViewUVScan(
        hits       = hits,
        miss       = miss,
        points     = intersections,
        normals    = normals,
        colors     = None,
        distances  = None,
        cam_pos    = cam_pos,
        cam_mat4   = None,
        proj_mat4  = None,
        transforms = {},
    )


def distance_from_rays_to_point_cloud(
        ray_origins     : np.ndarray, # (A, 3) or (1, 3) broadcast against ray_dirs
        ray_dirs        : np.ndarray, # (A, 3)
        points          : np.ndarray, # (B, 3)
        dirs_normalized : bool = False,
        n_steps         : int  = 40,
        ) -> np.ndarray: # (A)
    """
    Approximate the distance between each ray and a point cloud.

    Sphere-marches along every ray: query the nearest cloud point, advance by a quarter of
    that distance, and remember the nearest point seen. The march stops after ``n_steps``
    or once every query point has left the ball of radius 1.02 * (max norm of origins/points).
    The result is the distance from that nearest point to the line through the ray.
    Clouds larger than 640000 points use a faiss index (requires float32 inputs),
    otherwise a sklearn BallTree.
    """
    if np.size(ray_dirs) == 0:
        return np.zeros(np.shape(ray_dirs)[:-1], dtype=np.result_type(ray_origins, ray_dirs, points))

    # anything outside of this volume will never contribute to the result
    max_norm = max(
        np.linalg.norm(ray_origins, axis=-1).max(),
        np.linalg.norm(points,      axis=-1).max(),
    ) * 1.02

    if not dirs_normalized:
        ray_dirs = ray_dirs / np.linalg.norm(ray_dirs, axis=-1)[..., None]

    # deal with single-view clouds
    if ray_origins.shape != ray_dirs.shape:
        ray_origins = np.broadcast_to(ray_origins, ray_dirs.shape)

    n_points = np.prod(points.shape[:-1])
    use_faiss = n_points > 160000*4
    if not use_faiss:
        index = BallTree(points)
    else:
        # http://ann-benchmarks.com/index.html
        assert np.issubdtype(points.dtype,      np.float32)
        assert np.issubdtype(ray_origins.dtype, np.float32)
        assert np.issubdtype(ray_dirs.dtype,    np.float32)
        index = faiss.index_factory(points.shape[-1], "NSG32,Flat") # https://github.com/facebookresearch/faiss/wiki/The-index-factory
        index.nprobe = 5 # 10 # default is 1
        index.train(points)
        index.add(points)

    def nearest(query_points):
        # returns (distances, indices), both shaped (K, 1)
        if not use_faiss:
            return index.query(query_points, k=1, return_distance=True)
        d, n = index.search(query_points, k=1)
        return np.sqrt(d), n # faiss reports squared L2 distances

    min_d, min_n = nearest(ray_origins)
    acc_d = min_d.copy()

    for _ in range(n_steps):
        query_points = ray_origins + acc_d * ray_dirs
        qmask = np.linalg.norm(query_points, axis=-1) < max_norm
        if not qmask.any():
            break
        current_d, current_n = nearest(query_points[qmask])

        min_d[qmask] = np.minimum(current_d, min_d[qmask])
        new_min_mask = min_d[qmask] == current_d
        qmask2 = qmask.copy()
        qmask2[qmask2] = new_min_mask[..., 0]
        min_n[qmask2] = current_n[new_min_mask[..., 0]]
        acc_d[qmask] += current_d * 0.25

    closest_points = points[min_n[:, 0], :] # k=1
    distances = np.linalg.norm(np.cross(closest_points - ray_origins, ray_dirs, axis=-1), axis=-1)
    return distances


# helpers

@compose(np.array) # make copy to avoid lru cache mutation
@lru_cache(maxsize=1)
def generate_equidistant_sphere_rays(n : int, **kw) -> np.ndarray: # output (n*n(-1)) rays, n may be off
    """
    Return every ordered pair (start, end) of distinct equidistant sphere points.

    Output shape is (N*(N-1), 2, 3), where N is the number of points actually generated.
    """
    sphere_points = points.generate_equidistant_sphere_points(n=n, **kw)

    indices = np.indices((len(sphere_points),))[0] # (N)
    # cartesian product
    cprod = np.transpose([np.tile(indices, len(indices)), np.repeat(indices, len(indices))]) # (N**2, 2)
    # filter repeated combinations
    permutations = cprod[cprod[:, 0] != cprod[:, 1], :] # (N*(N-1), 2)

    # lookup sphere points
    two_sphere = sphere_points[permutations, :] # (N*(N-1), 2, 3)

    return two_sphere


def compute_unit_sphere_transform(mesh: Trimesh, *, dtype: Optional[type] = None) -> tuple[np.ndarray, float]:
    """
    returns translation and scale which mesh_to_sdf applies to meshes before computing their SDF cloud

    The unit-sphere position of a point x is (x + translation) * scale. Both values are cast to
    ``dtype`` when it is given.
    """
    # the transformation applied by mesh_to_sdf.scale_to_unit_sphere(mesh)
    translation = -mesh.bounding_box.centroid
    scale = 1 / np.max(np.linalg.norm(mesh.vertices + translation, axis=1))
    if dtype is not None:
        translation = translation.astype(dtype)
        scale = scale.astype(dtype)
    return translation, scale