This commit is contained in:
2023-07-19 19:29:10 +02:00
parent b2a64395bd
commit 4f811cc4b0
60 changed files with 18209 additions and 1 deletions

View File

View File

@@ -0,0 +1,90 @@
from ...utils.helpers import make_relative
from pathlib import Path
from tqdm import tqdm
from typing import Union, Optional
import io
import os
import json
import requests
PathLike = Union[os.PathLike, str]
__doc__ = """
Here are some helper functions for processing data.
"""
def check_url(url):
    """Return True if `url` answers an HTTP HEAD request with a non-error status."""
    response = requests.head(url)
    return response.ok
def download_stream(
    url        : str,
    file_object,
    block_size : int = 1024,
    silent     : bool = False,
    label      : Optional[str] = None,
):
    """
    Stream the resource at `url` into `file_object` in chunks of `block_size` bytes.

    Shows a tqdm progress bar unless `silent` is set; `label` becomes the bar
    description. Prints an error if the number of received bytes does not match
    the server's Content-Length header (when one is provided).
    """
    # `stream=True` defers the body download so it can be consumed chunk-wise;
    # the `with` guarantees the connection is released afterwards.
    with requests.get(url, stream=True) as resp:
        total_size = int(resp.headers.get("content-length", 0))
        progress_bar = None if silent else tqdm(total=total_size, unit="iB", unit_scale=True, desc=label)
        # bugfix: count bytes ourselves so the size check below also runs when
        # silent (previously it relied on progress_bar.n and was skipped entirely)
        n_written = 0
        for chunk in resp.iter_content(block_size):
            if progress_bar is not None:
                progress_bar.update(len(chunk))
            n_written += len(chunk)
            file_object.write(chunk)
        if progress_bar is not None:
            progress_bar.close()
        if total_size != 0 and n_written != total_size:
            print("ERROR, something went wrong")
def download_data(
    url        : str,
    block_size : int = 1024,
    silent     : bool = False,
    label      : Optional[str] = None,
) -> bytearray:
    """Download `url` into memory and return its body as a mutable bytearray."""
    buffer = io.BytesIO()
    download_stream(url, buffer, block_size=block_size, silent=silent, label=label)
    return bytearray(buffer.getvalue())
def download_file(
    url        : str,
    fname      : Union[Path, str],
    block_size : int = 1024,
    silent     = False,
):
    """Download `url` to the file `fname`, labeling progress with its relative name."""
    path = fname if isinstance(fname, Path) else Path(fname)
    with path.open("wb") as file_object:
        label = make_relative(path, Path.cwd()).name
        download_stream(url, file_object, block_size=block_size, silent=silent, label=label)
def is_downloaded(
    target_dir : PathLike,
    url        : str,
    *,
    add        : bool = False,
    dbfiles    : Union[list[PathLike], PathLike],
):
    """
    Check whether `url` is recorded as downloaded in one of the JSON database
    files under `target_dir`.

    Each database file is a JSON object with a "downloaded" list of urls.
    If `add` is set and the url is not yet recorded, it is appended to the
    first database file and True is returned.
    """
    # bugfix: an os.PathLike that is not a Path lacks the `/` operator used below
    if not isinstance(target_dir, Path):
        target_dir = Path(target_dir)
    if not isinstance(dbfiles, list):
        dbfiles = [dbfiles]
    if not dbfiles:
        raise ValueError("'dbfiles' empty")
    downloaded = set()
    for dbfile_fname in dbfiles:
        dbfile_fname = target_dir / dbfile_fname
        if dbfile_fname.is_file():
            with open(dbfile_fname, "r") as f:
                downloaded.update(json.load(f)["downloaded"])
    if add and url not in downloaded:
        downloaded.add(url)
        # bugfix: write the database inside `target_dir`, matching the read path
        # above (previously this wrote relative to the current working directory)
        with open(target_dir / dbfiles[0], "w") as f:
            data = {"downloaded": sorted(downloaded)}
            json.dump(data, f, indent=2, sort_keys=True)
        return True
    return url in downloaded

View File

@@ -0,0 +1,370 @@
#!/usr/bin/env python3
from abc import abstractmethod, ABCMeta
from collections import namedtuple
from pathlib import Path
import copy
import dataclasses
import functools
import h5py as h5
import hdf5plugin
import numpy as np
import operator
import os
import sys
import typing
__all__ = [
"DataclassMeta",
"Dataclass",
"H5Dataclass",
"H5Array",
"H5ArrayNoSlice",
]
T = typing.TypeVar("T")
NoneType = type(None)
PathLike = typing.Union[os.PathLike, str]
H5Array = typing._alias(np.ndarray, 0, inst=False, name="H5Array")
H5ArrayNoSlice = typing._alias(np.ndarray, 0, inst=False, name="H5ArrayNoSlice")
DataclassField = namedtuple("DataclassField", [
"name",
"type",
"is_optional",
"is_array",
"is_sliceable",
"is_prefix",
])
def strip_optional(val: type) -> type:
    """Unwrap `typing.Optional[X]` to `X`; reject other Unions; pass everything else through."""
    if typing.get_origin(val) is not typing.Union:
        return val
    members = set(typing.get_args(val)) - {type(None)}
    if len(members) != 1:
        raise TypeError(f"Non-'typing.Optional' 'typing.Union' is not supported: {typing._type_repr(val)!r}")
    return members.pop()
def is_array(val, *, _inner=False):
    """
    Hacky way to check if a value or type is an array.
    The hack omits having to depend on large frameworks such as pytorch or pandas
    """
    val = strip_optional(val)
    # identity checks: the H5 aliases are distinct objects even though they
    # wrap the same underlying numpy type
    if val is H5Array or val is H5ArrayNoSlice:
        return True
    # compare by repr so torch need not be importable here
    if typing._type_repr(val) in ("numpy.ndarray", "torch.Tensor"):
        return True
    # `val` may be an instance rather than a type: retry once with its type
    return is_array(type(val), _inner=True) if not _inner else False
def prod(numbers: typing.Iterable[T], initial: typing.Optional[T] = None) -> T:
    """Multiply `numbers` together, seeding the product with `initial` when given."""
    if initial is None:
        return functools.reduce(operator.mul, numbers)
    return functools.reduce(operator.mul, numbers, initial)
class DataclassMeta(type):
    """
    Metaclass that turns every class created with it into a `dataclasses.dataclass`.

    On Python 3.10+ the dataclass is generated with `slots=True` (unless the
    class already declares `__slots__`), reducing per-instance memory.
    """
    def __new__(
        mcls,
        name  : str,
        bases : tuple[type, ...],
        attrs : dict[str, typing.Any],
        **kwargs,
    ):
        # Build the class normally first, then wrap it as a dataclass.
        cls = super().__new__(mcls, name, bases, attrs, **kwargs)
        if sys.version_info[:2] >= (3, 10) and not hasattr(cls, "__slots__"):
            # NOTE: dataclass(slots=True) returns a *new* class object
            cls = dataclasses.dataclass(slots=True)(cls)
        else:
            cls = dataclasses.dataclass(cls)
        return cls
class DataclassABCMeta(DataclassMeta, ABCMeta):
    """Combined metaclass so a dataclass can also use abc machinery without a metaclass conflict."""
    pass
class Dataclass(metaclass=DataclassMeta):
    """
    Base class giving dataclasses a dict-like interface: `obj["field"]` access,
    `keys`/`values`/`items` views, dict/tuple conversion, and copying.
    """
    def __getitem__(self, key: str) -> typing.Any:
        if key not in self.keys():
            raise KeyError(key)
        return getattr(self, key)
    def __setitem__(self, key: str, value: typing.Any):
        if key not in self.keys():
            raise KeyError(key)
        return setattr(self, key, value)
    def keys(self) -> typing.KeysView:
        return self.as_dict().keys()
    def values(self) -> typing.ValuesView:
        return self.as_dict().values()
    def items(self) -> typing.ItemsView:
        return self.as_dict().items()
    def as_dict(self, properties_to_include: set[str] = None, **kw) -> dict[str, typing.Any]:
        """Return the fields as a dict, optionally including named non-field properties."""
        out = dataclasses.asdict(self, **kw)
        for prop in (properties_to_include or []):
            out[prop] = getattr(self, prop)
        return out
    def as_tuple(self, properties_to_include: list[str]) -> tuple:
        """Return the fields as a tuple, with the named property values appended."""
        base = dataclasses.astuple(self)
        if not properties_to_include:
            return base
        extras = tuple(getattr(self, prop) for prop in properties_to_include)
        return base + extras
    def copy(self: T, *, deep=True) -> T:
        """Return a copy of this instance (deep by default)."""
        return copy.deepcopy(self) if deep else copy.copy(self)
class H5Dataclass(Dataclass):
    """
    A `Dataclass` that can be read from and written to a HDF5 file.

    Subclass parameters (given on the `class` line):
      prefix      -- string prepended to every HDF5 dataset name
      n_pages     -- default number of pages for paginated reads
      require_all -- when True, reading fails if any HDF5 dataset is left unconsumed

    Field type conventions:
      H5Array         -- numpy array, sliceable/paginatable on read
      H5ArrayNoSlice  -- numpy array, always read whole
      dict[str, X]    -- "prefix" field: gathers all datasets named `<field>_<key>`
      typing.Optional -- field may be missing from the file (read as None)
    """
    # settable with class params:
    _prefix      : str  = dataclasses.field(init=False, repr=False, default="")
    _n_pages     : int  = dataclasses.field(init=False, repr=False, default=10)
    _require_all : bool = dataclasses.field(init=False, repr=False, default=False)
    def __init_subclass__(cls,
        prefix      : typing.Optional[str]  = None,
        n_pages     : typing.Optional[int]  = None,
        require_all : typing.Optional[bool] = None,
        **kw,
    ):
        super().__init_subclass__(**kw)
        assert dataclasses.is_dataclass(cls)
        if prefix      is not None: cls._prefix      = prefix
        if n_pages     is not None: cls._n_pages     = n_pages
        if require_all is not None: cls._require_all = require_all
    @classmethod
    def _get_fields(cls) -> typing.Iterable[DataclassField]:
        """Yield a `DataclassField` descriptor for every __init__ field of `cls`."""
        for field in dataclasses.fields(cls):
            if not field.init:
                continue
            assert field.name not in ("_prefix", "_n_pages", "_require_all"), (
                f"{field.name!r} can not be in {cls.__qualname__}.__init__.\n"
                "Set it with dataclasses.field(default=YOUR_VALUE, init=False, repr=False)"
            )
            if isinstance(field.type, str):
                raise TypeError("Type hints are strings, perhaps avoid using `from __future__ import annotations`")
            type_inner = strip_optional(field.type)
            # a `dict[str, X]` field collects all datasets sharing the field-name prefix
            is_prefix  = typing.get_origin(type_inner) is dict and typing.get_args(type_inner)[:1] == (str,)
            field_type = typing.get_args(type_inner)[1] if is_prefix else field.type
            if field.default is None or (typing.get_origin(field.type) is typing.Union and NoneType in typing.get_args(field.type)):
                field_type = typing.Optional[field_type]
            yield DataclassField(
                name         = field.name,
                type         = strip_optional(field_type),
                is_optional  = typing.get_origin(field_type) is typing.Union and NoneType in typing.get_args(field_type),
                is_array     = is_array(field_type),
                is_sliceable = is_array(field_type) and strip_optional(field_type) is not H5ArrayNoSlice,
                is_prefix    = is_prefix,
            )
    @classmethod
    def from_h5_file(cls : type[T],
        fname              : typing.Union[PathLike, str],
        *,
        page               : typing.Optional[int] = None,
        n_pages            : typing.Optional[int] = None,
        read_slice         : slice                = slice(None),
        require_even_pages : bool                 = True,
    ) -> T:
        """
        Construct an instance from the HDF5 file `fname`.

        When `page` is given, each sliceable array field is read in `n_pages`
        even chunks along its first axis and only chunk `page` is loaded.
        Raises FileNotFoundError / TypeError for missing or non-HDF5 files and,
        when the class declares no fields, a NotImplementedError containing a
        generated example class matching the file's contents.
        """
        if not isinstance(fname, Path):
            fname = Path(fname)
        if n_pages is None:
            n_pages = cls._n_pages
        if not fname.exists():
            raise FileNotFoundError(str(fname))
        if not h5.is_hdf5(fname):
            raise TypeError(f"Not a HDF5 file: {str(fname)!r}")
        # if this class has no fields, print a example class:
        if not any(field.init for field in dataclasses.fields(cls)):
            with h5.File(fname, "r") as f:
                klen = max(map(len, f.keys()))
                example_cls = f"\nclass {cls.__name__}(Dataclass, require_all=True):\n" + "\n".join(
                    f"    {k.ljust(klen)} : "
                    + (
                        "H5Array" if prod(v.shape, 1) > 1 else (
                        "float" if issubclass(v.dtype.type, np.floating) else (
                        "int" if issubclass(v.dtype.type, np.integer) else (
                        "bool" if issubclass(v.dtype.type, np.bool_) else (
                        "typing.Any"
                    ))))).ljust(14 + 1)
                    + f" #{repr(v).split(':', 1)[1].removesuffix('>')}"
                    for k, v in f.items()
                )
            raise NotImplementedError(f"{cls!r} has no fields!\nPerhaps try the following:{example_cls}")
        fields_consumed = set()
        def make_kwarg(
            file  : h5.File,
            keys  : typing.KeysView,
            field : DataclassField,
        ) -> tuple[str, typing.Any]:
            # Produce one (name, value) pair for cls(**kwargs).
            if field.is_optional:
                if field.name not in keys:
                    return field.name, None
            if field.is_sliceable:
                if page is not None:
                    # consistency fix: use the `file` parameter (this previously
                    # closed over the outer handle `f` and ignored `file`)
                    n_items  = int(file[cls._prefix + field.name].shape[0])
                    page_len = n_items // n_pages
                    modulus  = n_items % n_pages
                    if modulus: page_len += 1 # round up
                    if require_even_pages and modulus:
                        raise ValueError(f"Field {field.name!r} {tuple(file[cls._prefix + field.name].shape)} is not cleanly divisible into {n_pages} pages")
                    # bugfix: slice() does not accept keyword arguments
                    this_slice = slice(
                        page_len * page,       # start
                        page_len * (page + 1), # stop
                        read_slice.step,       # inherit step
                    )
                else:
                    this_slice = read_slice
            else:
                this_slice = slice(None) # read all
            # array or scalar?
            def read_dataset(var):
                # https://docs.h5py.org/en/stable/high/dataset.html#reading-writing-data
                if field.is_array:
                    return var[this_slice]
                if var.shape == (1,):
                    return var[0]
                else:
                    return var[()]
            if field.is_prefix:
                # collect every dataset named "<prefix><field>_<key>" into a dict
                fields_consumed.update(
                    key
                    for key in keys if key.startswith(f"{cls._prefix}{field.name}_")
                )
                return field.name, {
                    key.removeprefix(f"{cls._prefix}{field.name}_") : read_dataset(file[key])
                    for key in keys if key.startswith(f"{cls._prefix}{field.name}_")
                }
            else:
                fields_consumed.add(cls._prefix + field.name)
                return field.name, read_dataset(file[cls._prefix + field.name])
        with h5.File(fname, "r") as f:
            keys = f.keys()
            init_dict = dict( make_kwarg(f, keys, i) for i in cls._get_fields() )
            try:
                out = cls(**init_dict)
            except Exception as e:
                # enrich the error with the field/file diff to ease debugging
                class_attrs = set(field.name for field in dataclasses.fields(cls))
                file_attr = set(init_dict.keys())
                raise e.__class__(f"{e}. {class_attrs=}, {file_attr=}, diff={class_attrs.symmetric_difference(file_attr)}") from e
            if cls._require_all:
                fields_not_consumed = set(keys) - fields_consumed
                if fields_not_consumed:
                    raise ValueError(f"Not all HDF5 fields consumed: {fields_not_consumed!r}")
        return out
    def to_h5_file(self,
        fname : PathLike,
        mkdir : bool = False,
    ):
        """Write every set field to `fname` as HDF5 datasets (arrays LZ4-compressed)."""
        if not isinstance(fname, Path):
            fname = Path(fname)
        if not fname.parent.is_dir():
            if mkdir:
                fname.parent.mkdir(parents=True)
            else:
                raise NotADirectoryError(fname.parent)
        with h5.File(fname, "w") as f:
            for field in type(self)._get_fields():
                if field.is_optional and getattr(self, field.name) is None:
                    continue
                value = getattr(self, field.name)
                if field.is_array:
                    # require raw numpy arrays; torch tensors etc. must be converted first
                    if any(type(i) is not np.ndarray for i in (value.values() if field.is_prefix else [value])):
                        raise TypeError(
                            "When dumping a H5Dataclass, make sure the array fields are "
                            f"numpy arrays (the type of {field.name!r} is {typing._type_repr(type(value))}).\n"
                            "Example: h5dataclass.map_arrays(torch.Tensor.numpy)"
                        )
                def write_value(key: str, value: typing.Any):
                    # LZ4-compress array datasets; store scalars uncompressed
                    if field.is_array:
                        f.create_dataset(key, data=value, **hdf5plugin.LZ4())
                    else:
                        f.create_dataset(key, data=value)
                if field.is_prefix:
                    for k, v in value.items():
                        write_value(self._prefix + field.name + "_" + k, v)
                else:
                    write_value(self._prefix + field.name, value)
    def map_arrays(self: T, func: typing.Callable[[H5Array], H5Array], do_copy: bool = False) -> T:
        """Apply `func` to every array field (including prefix dicts), optionally on a shallow copy."""
        if do_copy: # shallow
            self = self.copy(deep=False)
        for field in type(self)._get_fields():
            if field.is_optional and getattr(self, field.name) is None:
                continue
            if field.is_prefix and field.is_array:
                setattr(self, field.name, {
                    k : func(v)
                    for k, v in getattr(self, field.name).items()
                })
            elif field.is_array:
                setattr(self, field.name, func(getattr(self, field.name)))
        return self
    def astype(self: T, t: type, do_copy: bool = False, convert_nonfloats: bool = False) -> T:
        """Cast float arrays (all arrays when `convert_nonfloats`) to dtype `t`."""
        # bugfix: `do_copy` was previously accepted but never forwarded
        return self.map_arrays(
            lambda x: x.astype(t) if convert_nonfloats or not np.issubdtype(x.dtype, int) else x,
            do_copy=do_copy,
        )
    def copy(self: T, *, deep=True) -> T:
        """Copy the instance; a shallow copy still gets its own prefix dicts."""
        out = super().copy(deep=deep)
        if not deep:
            for field in type(self)._get_fields():
                if field.is_prefix:
                    # bugfix: previously copied the field *name* string, which
                    # replaced the dict field with a str on shallow copies
                    out[field.name] = copy.copy(self[field.name])
        return out
    @property
    def shape(self) -> dict[str, tuple[int, ...]]:
        """Shapes of every field value that exposes a `.shape` attribute."""
        return {
            key: value.shape
            for key, value in self.items()
            if hasattr(value, "shape")
        }
class TransformableDataclassMixin(metaclass=DataclassABCMeta):
    """Mixin for dataclasses carrying a `transforms` dict of named 4x4 matrices."""
    @abstractmethod
    def transform(self: T, mat4: np.ndarray, inplace=False) -> T:
        # Apply the 4x4 homogeneous transform to all geometric fields.
        ...
    def transform_to(self: T, name: str, inverse_name: str = None, *, inplace=False) -> T:
        """
        Apply the stored transform `name`, then rebase the remaining stored
        transforms so they still map from the *new* coordinate frame.
        When `inverse_name` is given, the inverse is stored under that name.
        """
        mtx = self.transforms[name]
        out = self.transform(mtx, inplace=inplace)
        out.transforms.pop(name) # consumed
        inv = np.linalg.inv(mtx)
        for key in list(out.transforms.keys()): # maintain the other transforms
            out.transforms[key] = out.transforms[key] @ inv
        if inverse_name is not None: # store inverse
            out.transforms[inverse_name] = inv
        return out

View File

@@ -0,0 +1,48 @@
from math import pi
from trimesh import Trimesh
import numpy as np
import os
import trimesh
import trimesh.transformations as T
DEBUG = bool(os.environ.get("IFIELD_DEBUG", ""))
__doc__ = """
Here are some helper functions for processing data.
"""
def rotate_to_closest_axis_aligned_bounds(
    mesh       : Trimesh,
    order_axes : bool = True,
    fail_ok    : bool = True,
) -> np.ndarray:
    """
    Return a 4x4 rotation matrix aligning `mesh`'s oriented bounding box with
    the coordinate axes, choosing (when `order_axes`) the candidate rotation
    closest to the identity.
    """
    to_origin_mat4, extents = trimesh.bounds.oriented_bounds(mesh, ordered=not order_axes)
    # keep only the rotational part of the OBB transform
    to_aabb_rot_mat4 = T.euler_matrix(*T.decompose_matrix(to_origin_mat4)[3])
    if not order_axes:
        return to_aabb_rot_mat4
    tolerance = pi / 4 * 1.01
    quarter   = pi / 2
    face_turns = (
        (0, 0),
        (1, 0),
        (2, 0),
        (3, 0),
        (0, 1),
        (0,-1),
    )
    # 6 faces x 4 in-plane rotations per face
    candidates = [
        (fx * quarter, fy * quarter, k * quarter)
        for k in range(4)
        for fx, fy in face_turns
    ]
    for angles in candidates:
        mat4 = T.euler_matrix(*angles) @ to_aabb_rot_mat4
        ai, aj, ak = T.euler_from_matrix(mat4)
        if max(abs(ai), abs(aj), abs(ak)) <= tolerance:
            return mat4
    if fail_ok:
        return to_aabb_rot_mat4
    raise Exception("Unable to orient mesh")

View File

@@ -0,0 +1,297 @@
from __future__ import annotations
from ...utils.helpers import compose
from functools import reduce, lru_cache
from math import ceil
from typing import Iterable
import numpy as np
import operator
__doc__ = """
Here are some helper functions for processing data.
"""
def img2col(img: np.ndarray, psize: int) -> np.ndarray:
    """
    Rearrange `img` into overlapping `psize` x `psize` patches.

    Returns an array of shape (rows-psize+1, cols-psize+1, n_channels, psize, psize)
    (axes of size 1 squeezed away) where entry [i, j] is the patch whose
    top-left corner is at pixel (i, j).
    """
    # based of ycb_generate_point_cloud.py provided by YCB
    # a missing leading channel axis is treated as a single channel
    # (bugfix: removed a dead n_channels assignment that was immediately overwritten)
    n_channels, rows, cols = (1,) * (3 - len(img.shape)) + img.shape
    # pad the image up to a multiple of psize
    img_pad = np.zeros((
        n_channels,
        int(ceil(1.0 * rows / psize) * psize),
        int(ceil(1.0 * cols / psize) * psize),
    ))
    img_pad[:, 0:rows, 0:cols] = img
    # allocate output buffer
    final = np.zeros((
        img_pad.shape[1],
        img_pad.shape[2],
        n_channels,
        psize,
        psize,
    ))
    for c in range(n_channels):
        for x in range(psize):
            for y in range(psize):
                # cyclically shift the image by (x, y) ...
                img_shift = np.vstack((
                    img_pad[c, x:],
                    img_pad[c, :x]))
                img_shift = np.column_stack((
                    img_shift[:, y:],
                    img_shift[:, :y]))
                # ... so its non-overlapping psize-blocks are exactly the
                # patches whose top-left corner is at (x, y) modulo psize
                final[x::psize, y::psize, c] = np.swapaxes(
                    img_shift.reshape(
                        int(img_pad.shape[1] / psize), psize,
                        int(img_pad.shape[2] / psize), psize),
                    1,
                    2)
    # crop output and unwrap axes with size==1
    return np.squeeze(final[
        0:rows - psize + 1,
        0:cols - psize + 1])
def filter_depth_discontinuities(depth_map: np.ndarray, filt_size = 7, thresh = 1000) -> np.ndarray:
    """
    Removes data close to discontinuities, with size filt_size.
    """
    # based of ycb_generate_point_cloud.py provided by YCB
    # Ensure that filter sizes are okay
    assert filt_size % 2, "Can only use odd filter sizes."
    offset = int(filt_size - 1) // 2
    # Compute discontinuities: per patch, the largest deviation from its center
    patches = 1.0 * img2col(depth_map, filt_size)
    centers = patches[:, :, offset, offset]
    lows    = np.min(patches, axis=(2, 3))
    highs   = np.max(patches, axis=(2, 3))
    discont = np.maximum(
        np.abs(lows - centers),
        np.abs(highs - centers))
    mark = discont > thresh
    # Account for offsets: place the patch mask back at patch centers
    final_mark = np.zeros(depth_map.shape, dtype=np.uint16)
    final_mark[offset:offset + mark.shape[0],
               offset:offset + mark.shape[1]] = mark
    # zero out the marked depths
    return depth_map * (1 - final_mark)
def reorient_depth_map(
    depth_map        : np.ndarray,
    rgb_map          : np.ndarray,
    depth_mat3       : np.ndarray, # 3x3 intrinsic camera matrix
    depth_vec5       : np.ndarray, # 5 distortion parameters (k1, k2, p1, p2, k3)
    rgb_mat3         : np.ndarray, # 3x3 intrinsic camera matrix
    rgb_vec5         : np.ndarray, # 5 distortion parameters (k1, k2, p1, p2, k3)
    ir_to_rgb_mat4   : np.ndarray, # extrinsic transformation matrix from depth to rgb camera viewpoint
    rgb_mask_map     : np.ndarray = None,
    _output_points   = False, # retval (H, W) if false else (N, XYZRGB)
    _output_hits_uvs = False, # retval[1] is dtype=bool of hits shaped like depth_map
) -> np.ndarray:
    """
    Corrects depth_map to be from the same view as the rgb_map, with the same dimensions.
    If _output_points is True, the points returned are in the rgb camera space.
    """
    # based of ycb_generate_point_cloud.py provided by YCB
    # now faster AND more easy on the GIL
    height_old, width_old, *_ = depth_map.shape
    height, width, *_ = rgb_map.shape
    d_cx, r_cx = depth_mat3[0, 2], rgb_mat3[0, 2] # optical center
    d_cy, r_cy = depth_mat3[1, 2], rgb_mat3[1, 2]
    d_fx, r_fx = depth_mat3[0, 0], rgb_mat3[0, 0] # focal length
    d_fy, r_fy = depth_mat3[1, 1], rgb_mat3[1, 1]
    d_k1, d_k2, d_p1, d_p2, d_k3 = depth_vec5
    # NOTE(review): the rgb distortion parameters are unpacked but never applied
    c_k1, c_k2, c_p1, c_p2, c_k3 = rgb_vec5
    # make a UV grid over depth_map
    u, v = np.meshgrid(
        np.arange(width_old),
        np.arange(height_old),
    )
    # compute xyz coordinates for all depths (homogeneous, one column per pixel)
    xyz_depth = np.stack((
        (u - d_cx) / d_fx,
        (v - d_cy) / d_fy,
        depth_map,
        np.ones(depth_map.shape)
    )).reshape((4, -1))
    # drop pixels with no depth reading
    xyz_depth = xyz_depth[:, xyz_depth[2] != 0]
    # undistort depth coordinates (Brown-Conrady radial + tangential model)
    d_x, d_y = xyz_depth[:2]
    r = np.linalg.norm(xyz_depth[:2], axis=0)
    xyz_depth[0, :] \
        = d_x / (1 + d_k1*r**2 + d_k2*r**4 + d_k3*r**6) \
        - (2*d_p1*d_x*d_y + d_p2*(r**2 + 2*d_x**2))
    xyz_depth[1, :] \
        = d_y / (1 + d_k1*r**2 + d_k2*r**4 + d_k3*r**6) \
        - (d_p1*(r**2 + 2*d_y**2) + 2*d_p2*d_x*d_y)
    # unproject x and y
    xyz_depth[0, :] *= xyz_depth[2, :]
    xyz_depth[1, :] *= xyz_depth[2, :]
    # convert depths to RGB camera viewpoint
    xyz_rgb = ir_to_rgb_mat4 @ xyz_depth
    # project depths to RGB canvas
    rgb_z_inv = 1 / xyz_rgb[2] # perspective correction
    # bugfix: np.int was removed in numpy >= 1.24; builtin int is the same dtype
    rgb_uv = np.stack((
        xyz_rgb[0] * rgb_z_inv * r_fx + r_cx + 0.5,
        xyz_rgb[1] * rgb_z_inv * r_fy + r_cy + 0.5,
    )).astype(int)
    # mask of the rgb_xyz values within view of rgb_map
    mask = reduce(operator.and_, [
        rgb_uv[0] >= 0,
        rgb_uv[1] >= 0,
        rgb_uv[0] < width,
        rgb_uv[1] < height,
    ])
    if rgb_mask_map is not None:
        mask[mask] &= rgb_mask_map[
            rgb_uv[1, mask],
            rgb_uv[0, mask]]
    if not _output_points: # output image
        output = np.zeros((height, width), dtype=depth_map.dtype)
        output[
            rgb_uv[1, mask],
            rgb_uv[0, mask],
        ] = xyz_rgb[2, mask]
    else: # output pointcloud
        rgbs = rgb_map[ # lookup rgb values using rgb_uv
            rgb_uv[1, mask],
            rgb_uv[0, mask]]
        output = np.stack((
            xyz_rgb[0, mask], # x
            xyz_rgb[1, mask], # y
            xyz_rgb[2, mask], # z
            rgbs[:, 0], # r
            rgbs[:, 1], # g
            rgbs[:, 2], # b
        )).T
    # output for realsies
    if not _output_hits_uvs: #raw
        return output
    else: # with hit mask
        uv = np.zeros((height, width), dtype=bool)
        # filter points overlapping in the depth map
        uv_indices = (
            rgb_uv[1, mask],
            rgb_uv[0, mask],
        )
        _, chosen = np.unique( uv_indices[0] << 32 | uv_indices[1], return_index=True )
        output = output[chosen, :]
        uv[uv_indices] = True
        return output, uv
def join_rgb_and_depth_to_points(*a, **kw) -> np.ndarray:
    """Same as `reorient_depth_map`, but returns an (N, XYZRGB) point cloud."""
    return reorient_depth_map(*a, _output_points=True, **kw)
@compose(np.array) # block lru cache mutation
@lru_cache(maxsize=1)
@compose(list)
def generate_equidistant_sphere_points(
    n                          : int,
    centroid                   : np.ndarray = (0, 0, 0),
    radius                     : float = 1,
    compute_sphere_coordinates : bool = False,
    compute_normals            : bool = False,
    shift_theta                : bool = False,
) -> Iterable[tuple[float, ...]]:
    """
    Yield approximately `n` near-equidistant points on a sphere.

    Yields (theta, phi) pairs when `compute_sphere_coordinates`,
    (x, y, z, nx, ny, nz) tuples when `compute_normals`, otherwise (x, y, z).
    The decorator stack materializes the generator into a list, caches the
    last result, and returns it as a fresh numpy array each call.
    """
    # Deserno M. How to generate equidistributed points on the surface of a sphere
    # https://www.cmu.edu/biolphys/deserno/pdf/sphere_equi.pdf
    if compute_sphere_coordinates and compute_normals:
        raise ValueError(
            "'compute_sphere_coordinates' and 'compute_normals' are mutually exclusive"
        )
    n_count = 0  # points yielded so far (the final total may differ slightly from n)
    a = 4 * np.pi / n            # target surface area per point
    d = np.sqrt(a)               # target spacing between points
    n_theta = round(np.pi / d)   # number of latitude bands
    d_theta = np.pi / n_theta
    d_phi = a / d_theta
    for i in range(0, n_theta):
        theta = np.pi * (i + 0.5) / n_theta
        # points per band scale with the band's circumference
        n_phi = round(2 * np.pi * np.sin(theta) / d_phi)
        for j in range(0, n_phi):
            phi = 2 * np.pi * j / n_phi
            if compute_sphere_coordinates: # (theta, phi)
                yield (
                    theta if shift_theta else theta - 0.5*np.pi,
                    phi,
                )
            elif compute_normals: # (x, y, z, nx, ny, nz)
                yield (
                    centroid[0] + radius * np.sin(theta) * np.cos(phi),
                    centroid[1] + radius * np.sin(theta) * np.sin(phi),
                    centroid[2] + radius * np.cos(theta),
                    np.sin(theta) * np.cos(phi),
                    np.sin(theta) * np.sin(phi),
                    np.cos(theta),
                )
            else: # (x, y, z)
                yield (
                    centroid[0] + radius * np.sin(theta) * np.cos(phi),
                    centroid[1] + radius * np.sin(theta) * np.sin(phi),
                    centroid[2] + radius * np.cos(theta),
                )
            n_count += 1
def generate_random_sphere_points(
    n                          : int,
    centroid                   : np.ndarray = (0, 0, 0),
    radius                     : float = 1,
    compute_sphere_coordinates : bool = False,
    compute_normals            : bool = False,
    shift_theta                : bool = False, # depends on convention
) -> np.ndarray:
    """
    Sample `n` uniformly distributed points on a sphere.

    Returns an (n, 2) array of (theta, phi) when `compute_sphere_coordinates`,
    an (n, 6) array of xyz+normal rows when `compute_normals`,
    otherwise an (n, 3) array of xyz rows.
    """
    if compute_sphere_coordinates and compute_normals:
        raise ValueError(
            "'compute_sphere_coordinates' and 'compute_normals' are mutually exclusive"
        )
    # inverse transform sampling gives a uniform distribution over the sphere
    theta = np.arcsin(np.random.uniform(-1, 1, n))
    phi   = np.random.uniform(0, 2*np.pi, n)
    if compute_sphere_coordinates: # (theta, phi)
        latitudes = 0.5*np.pi + theta if shift_theta else theta
        return np.stack((latitudes, phi), axis=1)
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    xyz = (
        centroid[0] + radius * cos_t * np.cos(phi),
        centroid[1] + radius * cos_t * np.sin(phi),
        centroid[2] + radius * sin_t,
    )
    if compute_normals: # (x, y, z, nx, ny, nz)
        normals = (cos_t * np.cos(phi), cos_t * np.sin(phi), sin_t)
        return np.stack(xyz + normals, axis=1)
    return np.stack(xyz, axis=1) # (x, y, z)

View File

@@ -0,0 +1,85 @@
from .h5_dataclasses import H5Dataclass
from datetime import datetime, timedelta
from pathlib import Path
from typing import Hashable, Optional, Callable
import os
DEBUG = bool(os.environ.get("IFIELD_DEBUG", ""))
__doc__ = """
Here are some helper functions for processing data.
"""
# multiprocessing does not work due to my rediculous use of closures, which seemingly cannot be pickled
# paralelize it in the shell instead
def precompute_data(
    computer     : Callable[[Hashable], Optional[H5Dataclass]],
    identifiers  : list[Hashable],
    output_paths : list[Path],
    page         : tuple[int, int] = (0, 1),
    *,
    force        : bool = False,
    debug        : bool = False,
):
    """
    precomputes data and stores them as HDF5 datasets using `.to_file(path: Path)`

    `computer` maps an identifier to a dataclass (or None on failure);
    `page = (i, n)` processes only the i-th of n equally sized chunks.
    Existing non-empty outputs are skipped unless `force` is set. Failures are
    collected and reported; they only raise when `debug` (or IFIELD_DEBUG) is set.
    """
    page, n_pages = page
    assert len(identifiers) == len(output_paths)
    total = len(identifiers)
    # bugfix: `default=0` keeps an empty identifier list from crashing max()
    identifier_max_len = max(map(len, map(str, identifiers)), default=0)
    t_epoch = None
    def log(state: str, is_start = False):
        # NOTE: closes over `index` and `identifier` from the loop below
        nonlocal t_epoch
        if is_start: t_epoch = datetime.now()
        td = timedelta(0) if is_start else datetime.now() - t_epoch
        print(" - "
            f"{str(index+1).rjust(len(str(total)))}/{total}: "
            f"{str(identifier).ljust(identifier_max_len)} @ {td}: {state}"
        )
    print(f"precompute_data(computer={computer.__module__}.{computer.__qualname__}, identifiers=..., force={force}, page={page})")
    t_begin = datetime.now()
    failed = []
    # pagination: ceil(total / n_pages) jobs per page
    page_size = total // n_pages + bool(total % n_pages)
    jobs = list(zip(identifiers, output_paths))[page_size*page : page_size*(page+1)]
    for index, (identifier, output_path) in enumerate(jobs, start=page_size*page):
        # skip already-computed, non-empty outputs unless forced
        if not force and output_path.exists() and output_path.stat().st_size > 0:
            continue
        log("compute", is_start=True)
        # compute
        try:
            res = computer(identifier)
        except Exception as e:
            failed.append(identifier)
            log(f"failed compute: {e.__class__.__name__}: {e}")
            if DEBUG or debug: raise e
            continue
        if res is None:
            failed.append(identifier)
            log("no result")
            continue
        # write to file
        try:
            output_path.parent.mkdir(parents=True, exist_ok=True)
            res.to_h5_file(output_path)
        except Exception as e:
            failed.append(identifier)
            log(f"failed write: {e.__class__.__name__}: {e}")
            if output_path.is_file(): output_path.unlink() # cleanup
            if DEBUG or debug: raise e
            continue
        log("done")
    print("precompute_data finished in", datetime.now() - t_begin)
    print("failed:", failed or None)

768
ifield/data/common/scan.py Normal file
View File

@@ -0,0 +1,768 @@
from ...utils.helpers import compose
from . import points
from .h5_dataclasses import H5Dataclass, H5Array, H5ArrayNoSlice, TransformableDataclassMixin
from methodtools import lru_cache
from sklearn.neighbors import BallTree
import faiss
from trimesh import Trimesh
from typing import Iterable
from typing import Optional, TypeVar
import mesh_to_sdf
import mesh_to_sdf.scan as sdf_scan
import numpy as np
import trimesh
import trimesh.transformations as T
import warnings
__doc__ = """
Here are some helper types for data.
"""
_T = TypeVar("T")
class InvalidateLRUOnWriteMixin:
    """Mixin that clears every cached `lru_property` whenever a plain attribute is written."""
    def __setattr__(self, key, value):
        # cached wrappers live in attributes named "__wire|..."; writing one of
        # those must not itself re-trigger invalidation
        if not key.startswith("__wire|"):
            cache_attrs = [attr for attr in dir(self) if attr.startswith("__wire|")]
            for attr in cache_attrs:
                getattr(self, attr).cache_clear()
        return super().__setattr__(key, value)
def lru_property(func):
    # A property whose value is computed once and cached (methodtools.lru_cache
    # supports wrapping descriptors); InvalidateLRUOnWriteMixin clears it on writes.
    return lru_cache(maxsize=1)(property(func))
class SingleViewScan(H5Dataclass, TransformableDataclassMixin, InvalidateLRUOnWriteMixin, require_all=True):
    """
    A single depth-camera scan: hit points, miss rays, optional normals/colors,
    optional UV (image-grid) masks, and camera matrices.
    """
    points_hit     : H5ArrayNoSlice           # (N, 3)
    normals_hit    : Optional[H5ArrayNoSlice] # (N, 3)
    points_miss    : H5ArrayNoSlice           # (M, 3)
    distances_miss : Optional[H5ArrayNoSlice] # (M)
    colors_hit     : Optional[H5ArrayNoSlice] # (N, 3)
    colors_miss    : Optional[H5ArrayNoSlice] # (M, 3)
    uv_hits        : Optional[H5ArrayNoSlice] # (H, W) dtype=bool
    uv_miss        : Optional[H5ArrayNoSlice] # (H, W) dtype=bool (the reason we store both is due to missing data depth sensor data or filtered backfaces)
    cam_pos        : H5ArrayNoSlice           # (3)
    cam_mat4       : Optional[H5ArrayNoSlice] # (4, 4)
    proj_mat4      : Optional[H5ArrayNoSlice] # (4, 4)
    transforms     : dict[str, H5ArrayNoSlice] # a map of 4x4 transformation matrices
    def transform(self: _T, mat4: np.ndarray, inplace=False) -> _T:
        """Apply a 4x4 homogeneous transform (uniform scale + rotation + translation)."""
        # bugfix: per-axis scale is the *norm* of each rotation column
        # (https://math.stackexchange.com/a/1463487), not its sum -- the sum
        # fails the uniformity assert for any rotated matrix
        scale_xyz = np.linalg.norm(mat4[:3, :3], axis=0)
        assert all(abs(scale_xyz - scale_xyz[0]) < 1e-8), f"differently scaled axes: {scale_xyz}"
        scale = scale_xyz[0]
        out = self if inplace else self.copy(deep=False)
        out.points_hit  = T.transform_points(self.points_hit, mat4)
        # bugfix: normals are directions -- rotate them, do not translate them
        out.normals_hit = T.transform_points(self.normals_hit, mat4, translate=False) if self.normals_hit is not None else None
        out.points_miss = T.transform_points(self.points_miss, mat4)
        # bugfix: distances scale by the scalar uniform factor, and may be absent
        out.distances_miss = self.distances_miss * scale if self.distances_miss is not None else None
        out.cam_pos   = T.transform_points(self.points_cam, mat4)[-1]
        out.cam_mat4  = (mat4 @ self.cam_mat4)  if self.cam_mat4  is not None else None
        out.proj_mat4 = (mat4 @ self.proj_mat4) if self.proj_mat4 is not None else None
        return out
    def compute_miss_distances(self: _T, *, copy: bool = False, deep: bool = False) -> _T:
        """Fill `distances_miss` with each miss ray's distance to the hit point cloud."""
        assert not self.has_miss_distances
        if not self.is_hitting:
            raise ValueError("No hits to compute the ray distance towards")
        out = self.copy(deep=deep) if copy else self
        out.distances_miss \
            = distance_from_rays_to_point_cloud(
                ray_origins = out.points_cam,
                ray_dirs    = out.ray_dirs_miss,
                points      = out.points_hit,
            ).astype(out.points_cam.dtype)
        return out
    @lru_property
    def points(self) -> np.ndarray: # (N+M+1, 3)
        # all points: hits, then misses, then the camera position
        return np.concatenate((
            self.points_hit,
            self.points_miss,
            self.points_cam,
        ))
    @lru_property
    def uv_points(self) -> np.ndarray: # (H, W, 3)
        # image-shaped point buffer; NaN where neither hit nor miss
        if not self.has_uv: raise ValueError
        out = np.full((*self.uv_hits.shape, 3), np.nan, dtype=self.points_hit.dtype)
        out[self.uv_hits, :] = self.points_hit
        out[self.uv_miss, :] = self.points_miss
        return out
    @lru_property
    def uv_normals(self) -> np.ndarray: # (H, W, 3)
        # image-shaped normal buffer; NaN where there is no hit
        if not self.has_uv: raise ValueError
        out = np.full((*self.uv_hits.shape, 3), np.nan, dtype=self.normals_hit.dtype)
        out[self.uv_hits, :] = self.normals_hit
        return out
    @lru_property
    def points_cam(self) -> Optional[np.ndarray]: # (1, 3)
        if self.cam_pos is None: return None
        return self.cam_pos[None, :]
    @lru_property
    def points_hit_centroid(self) -> np.ndarray:
        return self.points_hit.mean(axis=0)
    @lru_property
    def points_hit_std(self) -> np.ndarray:
        return self.points_hit.std(axis=0)
    @lru_property
    def is_hitting(self) -> bool:
        return len(self.points_hit) > 0
    @lru_property
    def is_empty(self) -> bool:
        return not (len(self.points_hit) or len(self.points_miss))
    @lru_property
    def has_colors(self) -> bool:
        return self.colors_hit is not None or self.colors_miss is not None
    @lru_property
    def has_normals(self) -> bool:
        return self.normals_hit is not None
    @lru_property
    def has_uv(self) -> bool:
        return self.uv_hits is not None
    @lru_property
    def has_miss_distances(self) -> bool:
        return self.distances_miss is not None
    @lru_property
    def xyzrgb_hit(self) -> np.ndarray: # (N, 6)
        if self.colors_hit is None: raise ValueError
        return np.concatenate([self.points_hit, self.colors_hit], axis=1)
    @lru_property
    def xyzrgb_miss(self) -> np.ndarray: # (M, 6)
        if self.colors_miss is None: raise ValueError
        return np.concatenate([self.points_miss, self.colors_miss], axis=1)
    @lru_property
    def ray_dirs_hit(self) -> np.ndarray: # (N, 3)
        # unit direction from the camera towards each hit point
        out = self.points_hit - self.points_cam
        out /= np.linalg.norm(out, axis=-1)[:, None] # normalize
        return out
    @lru_property
    def ray_dirs_miss(self) -> np.ndarray: # (M, 3)
        # unit direction from the camera towards each miss point
        out = self.points_miss - self.points_cam
        out /= np.linalg.norm(out, axis=-1)[:, None] # normalize
        return out
    @classmethod
    def from_mesh_single_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewScan":
        """Scan `mesh` from a (random, unless theta/phi given) viewpoint."""
        if "phi" not in kw and "theta" not in kw:
            kw["theta"], kw["phi"] = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
        scan = sample_single_view_scan_from_mesh(mesh, **kw)
        if compute_miss_distances and scan.is_hitting:
            scan.compute_miss_distances()
        return scan
    def to_uv_scan(self) -> "SingleViewUVScan":
        return SingleViewUVScan.from_scan(self)
    @classmethod
    def from_uv_scan(cls, uvscan: "SingleViewUVScan") -> "SingleViewScan":
        # bugfix: return annotation previously claimed SingleViewUVScan; the
        # classmethod's first parameter was also misnamed `self`
        return uvscan.to_scan()
# The same, but with support for pagination (should have been this way since the start...)
class SingleViewUVScan(H5Dataclass, TransformableDataclassMixin, InvalidateLRUOnWriteMixin, require_all=True):
# B may be (N) or (H, W), the latter may be flattened
hits : H5Array # (*B) dtype=bool
miss : H5Array # (*B) dtype=bool (the reason we store both is due to missing data depth sensor data or filtered backface hits)
points : H5Array # (*B, 3) on far plane if miss, NaN if neither hit or miss
normals : Optional[H5Array] # (*B, 3) NaN if not hit
colors : Optional[H5Array] # (*B, 3)
distances : Optional[H5Array] # (*B) NaN if not miss
cam_pos : Optional[H5ArrayNoSlice] # (3) or (*B, 3)
cam_mat4 : Optional[H5ArrayNoSlice] # (4, 4)
proj_mat4 : Optional[H5ArrayNoSlice] # (4, 4)
transforms : dict[str, H5ArrayNoSlice] # a map of 4x4 transformation matrices
@classmethod
def from_scan(cls, scan: SingleViewScan):
    """Build a UV-indexed scan by scattering a `SingleViewScan`'s data through its UV masks."""
    if not scan.has_uv:
        raise ValueError("Scan cloud has no UV data")
    hits, miss = scan.uv_hits, scan.uv_miss
    dtype = scan.points_hit.dtype
    assert hits.ndim in (1, 2), hits.ndim
    assert hits.shape == miss.shape, (hits.shape, miss.shape)
    def scatter(shape, hit_values, miss_values):
        # NaN-filled buffer with hit/miss entries written through the masks
        buffer = np.full(shape, np.nan, dtype=dtype)
        if hit_values is not None:
            buffer[hits] = hit_values
        if miss_values is not None:
            buffer[miss] = miss_values
        return buffer
    points    = scatter((*hits.shape, 3), scan.points_hit, scan.points_miss)
    normals   = scatter((*hits.shape, 3), scan.normals_hit, None) if scan.has_normals        else None
    distances = scatter(hits.shape, None, scan.distances_miss)    if scan.has_miss_distances else None
    colors    = scatter((*hits.shape, 3), scan.colors_hit, scan.colors_miss) if scan.has_colors else None
    return cls(
        hits       = hits,
        miss       = miss,
        points     = points,
        normals    = normals,
        colors     = colors,
        distances  = distances,
        cam_pos    = scan.cam_pos,
        cam_mat4   = scan.cam_mat4,
        proj_mat4  = scan.proj_mat4,
        transforms = scan.transforms,
    )
def to_scan(self) -> "SingleViewScan":
if not self.is_single_view: raise ValueError
return SingleViewScan(
points_hit = self.points [self.hits, :],
points_miss = self.points [self.miss, :],
normals_hit = self.normals [self.hits, :] if self.has_normals else None,
distances_miss = self.distances[self.miss] if self.has_miss_distances else None,
colors_hit = self.colors [self.hits, :] if self.has_colors else None,
colors_miss = self.colors [self.miss, :] if self.has_colors else None,
uv_hits = self.hits,
uv_miss = self.miss,
cam_pos = self.cam_pos,
cam_mat4 = self.cam_mat4,
proj_mat4 = self.proj_mat4,
transforms = self.transforms,
)
def to_mesh(self) -> trimesh.Trimesh:
faces: list[(tuple[int, int],)*3] = []
for x in range(self.hits.shape[0]-1):
for y in range(self.hits.shape[1]-1):
c11 = x, y
c12 = x, y+1
c22 = x+1, y+1
c21 = x+1, y
n = sum(map(self.hits.__getitem__, (c11, c12, c22, c21)))
if n == 3:
faces.append((*filter(self.hits.__getitem__, (c11, c12, c22, c21)),))
elif n == 4:
faces.append((c11, c12, c22))
faces.append((c11, c22, c21))
xy2idx = {c:i for i, c in enumerate(set(k for j in faces for k in j))}
assert self.colors is not None
return trimesh.Trimesh(
vertices = [self.points[i] for i in xy2idx.keys()],
vertex_colors = [self.colors[i] for i in xy2idx.keys()] if self.colors is not None else None,
faces = [tuple(xy2idx[i] for i in face) for face in faces],
)
def transform(self: _T, mat4: np.ndarray, inplace=False) -> _T:
scale_xyz = mat4[:3, :3].sum(axis=0) # https://math.stackexchange.com/a/1463487
assert all(scale_xyz - scale_xyz[0] < 1e-8), f"differenty scaled axes: {scale_xyz}"
unflat = self.hits.shape
flat = np.product(unflat)
out = self if inplace else self.copy(deep=False)
out.points = T.transform_points(self.points .reshape((*flat, 3)), mat4).reshape((*unflat, 3))
out.normals = T.transform_points(self.normals.reshape((*flat, 3)), mat4).reshape((*unflat, 3)) if self.normals_hit is not None else None
out.distances = self.distances_miss * scale_xyz
out.cam_pos = T.transform_points(self.cam_pos[None, ...], mat4)[0]
out.cam_mat4 = (mat4 @ self.cam_mat4) if self.cam_mat4 is not None else None
out.proj_mat4 = (mat4 @ self.proj_mat4) if self.proj_mat4 is not None else None
return out
def compute_miss_distances(self: _T, *, copy: bool = False, deep: bool = False, surface_points: Optional[np.ndarray] = None) -> _T:
assert not self.has_miss_distances
shape = self.hits.shape
out = self.copy(deep=deep) if copy else self
out.distances = np.zeros(shape, dtype=self.points.dtype)
if self.is_hitting:
out.distances[self.miss] \
= distance_from_rays_to_point_cloud(
ray_origins = self.cam_pos_unsqueezed_miss,
ray_dirs = self.ray_dirs_miss,
points = surface_points if surface_points is not None else self.points[self.hits],
)
return out
def fill_missing_points(self: _T, *, copy: bool = False, deep: bool = False) -> _T:
"""
Fill in missing points as hitting the far plane.
"""
if not self.is_2d:
raise ValueError("Cannot fill missing points for non-2d scan!")
if not self.is_single_view:
raise ValueError("Cannot fill missing points for non-single-view scans!")
if self.cam_mat4 is None:
raise ValueError("cam_mat4 is None")
if self.proj_mat4 is None:
raise ValueError("proj_mat4 is None")
uv = np.argwhere(self.missing).astype(self.points.dtype)
uv[:, 0] /= (self.missing.shape[1] - 1) / 2
uv[:, 1] /= (self.missing.shape[0] - 1) / 2
uv -= 1
uv = np.stack((
uv[:, 1],
-uv[:, 0],
np.ones(uv.shape[0]), # far clipping plane
np.ones(uv.shape[0]), # homogeneous coordinate
), axis=-1)
uv = uv @ (self.cam_mat4 @ np.linalg.inv(self.proj_mat4)).T
out = self.copy(deep=deep) if copy else self
out.points[self.missing, :] = uv[:, :3] / uv[:, 3][:, None]
return out
@lru_property
def is_hitting(self) -> bool:
return np.any(self.hits)
@lru_property
def has_colors(self) -> bool:
return not self.colors is None
@lru_property
def has_normals(self) -> bool:
return not self.normals is None
@lru_property
def has_miss_distances(self) -> bool:
return not self.distances is None
@lru_property
def any_missing(self) -> bool:
return np.any(self.missing)
@lru_property
def has_missing(self) -> bool:
return self.any_missing and not np.any(np.isnan(self.points[self.missing]))
@lru_property
def cam_pos_unsqueezed(self) -> H5Array:
if self.cam_pos.ndim != 1:
return self.cam_pos
else:
cam_pos = self.cam_pos
for _ in range(self.hits.ndim):
cam_pos = cam_pos[None, ...]
return cam_pos
@lru_property
def cam_pos_unsqueezed_hit(self) -> H5Array:
if self.cam_pos.ndim != 1:
return self.cam_pos[self.hits, :]
else:
return self.cam_pos[None, :]
@lru_property
def cam_pos_unsqueezed_miss(self) -> H5Array:
if self.cam_pos.ndim != 1:
return self.cam_pos[self.miss, :]
else:
return self.cam_pos[None, :]
@lru_property
def ray_dirs(self) -> H5Array:
return (self.points - self.cam_pos_unsqueezed) * (1 / self.depths[..., None])
@lru_property
def ray_dirs_hit(self) -> H5Array:
out = self.points[self.hits, :] - self.cam_pos_unsqueezed_hit
out /= np.linalg.norm(out, axis=-1)[..., None] # normalize
return out
@lru_property
def ray_dirs_miss(self) -> H5Array:
out = self.points[self.miss, :] - self.cam_pos_unsqueezed_miss
out /= np.linalg.norm(out, axis=-1)[..., None] # normalize
return out
@lru_property
def depths(self) -> H5Array:
return np.linalg.norm(self.points - self.cam_pos_unsqueezed, axis=-1)
@lru_property
def missing(self) -> H5Array:
return ~(self.hits | self.miss)
@classmethod
def from_mesh_single_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewUVScan":
if "phi" not in kw and not "theta" in kw:
kw["theta"], kw["phi"] = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
scan = sample_single_view_scan_from_mesh(mesh, **kw).to_uv_scan()
if compute_miss_distances:
scan.compute_miss_distances()
assert scan.is_2d
return scan
@classmethod
def from_mesh_sphere_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewUVScan":
scan = sample_sphere_view_scan_from_mesh(mesh, **kw)
if compute_miss_distances:
surface_points = None
if scan.hits.sum() > mesh.vertices.shape[0]:
surface_points = mesh.vertices.astype(scan.points.dtype)
if not kw.get("no_unit_sphere", False):
translation, scale = compute_unit_sphere_transform(mesh, dtype=scan.points.dtype)
surface_points = (surface_points + translation) * scale
scan.compute_miss_distances(surface_points=surface_points)
assert scan.is_flat
return scan
def flatten_and_permute_(self: _T, copy=False) -> _T: # inplace by default
n_items = np.product(self.hits.shape)
permutation = np.random.permutation(n_items)
out = self.copy(deep=False) if copy else self
out.hits = out.hits .reshape((n_items, ))[permutation]
out.miss = out.miss .reshape((n_items, ))[permutation]
out.points = out.points .reshape((n_items, 3))[permutation, :]
out.normals = out.normals .reshape((n_items, 3))[permutation, :] if out.has_normals else None
out.colors = out.colors .reshape((n_items, 3))[permutation, :] if out.has_colors else None
out.distances = out.distances.reshape((n_items, ))[permutation] if out.has_miss_distances else None
return out
@property
def is_single_view(self) -> bool:
return np.product(self.cam_pos.shape[:-1]) == 1 if not self.cam_pos is None else True
@property
def is_flat(self) -> bool:
return len(self.hits.shape) == 1
@property
def is_2d(self) -> bool:
return len(self.hits.shape) == 2
# transforms can be found in pytorch3d.transforms and in open3d
# and in trimesh.transformations
def sample_single_view_scans_from_mesh(
    mesh : Trimesh,
    *,
    n_batches : int,
    scan_resolution : int = 400,
    compute_normals : bool = False,
    fov : float = 1.0472, # 60 degrees in radians, vertical field of view.
    camera_distance : float = 2,
    no_filter_backhits : bool = False,
) -> Iterable[SingleViewScan]:
    """
    Yield `n_batches` single-view scans of `mesh`, each taken from a uniformly
    random spherical view point.

    The unit-sphere normalization of `mesh` is computed once and shared across
    iterations through the `_mesh_cache` list.
    """
    normalized_mesh_cache = []
    for _ in range(n_batches):
        theta, phi = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
        # BUGFIX: `_mesh_is_normalized=False` used to be passed here, but
        # `sample_single_view_scan_from_mesh` has no such parameter (TypeError)
        yield sample_single_view_scan_from_mesh(
            mesh = mesh,
            phi = phi,
            theta = theta,
            scan_resolution = scan_resolution,
            compute_normals = compute_normals,
            fov = fov,
            camera_distance = camera_distance,
            no_filter_backhits = no_filter_backhits,
            _mesh_cache = normalized_mesh_cache,
        )
def sample_single_view_scan_from_mesh(
    mesh : Trimesh,
    *,
    phi : float,
    theta : float,
    scan_resolution : int = 200,
    compute_normals : bool = False,
    fov : float = 1.0472, # 60 degrees in radians, vertical field of view.
    camera_distance : float = 2,
    no_filter_backhits : bool = False,
    no_unit_sphere : bool = False,
    dtype : type = np.float32,
    _mesh_cache : Optional[list] = None, # provide a list if mesh is reused
) -> SingleViewScan:
    """
    Render one depth scan of `mesh` from the spherical view point (`phi`, `theta`).

    The mesh is scaled/centered to the unit sphere (mesh_to_sdf convention),
    scanned with a virtual depth camera, and the rays that reached the far
    plane are reconstructed explicitly so misses are represented as points too.
    With `no_unit_sphere`, all outputs are mapped back to model space.

    NOTE(review): the returned `scan.normals` is only populated when
    `compute_normals` is set — presumably `no_filter_backhits=True` is always
    combined with `compute_normals=True`; confirm before relying on it.
    """
    # scale and center to unit sphere
    is_cache = isinstance(_mesh_cache, list)
    if is_cache and _mesh_cache and _mesh_cache[0] is mesh:
        # cache hit: reuse the normalized mesh and its transform
        _, mesh, translation, scale = _mesh_cache
    else:
        if is_cache:
            if _mesh_cache:
                _mesh_cache.clear()
            # the cache is keyed by the identity of the original mesh
            _mesh_cache.append(mesh)
        translation, scale = compute_unit_sphere_transform(mesh)
        mesh = mesh_to_sdf.scale_to_unit_sphere(mesh)
        if is_cache:
            _mesh_cache.extend((mesh, translation, scale))
    z_near = 1
    z_far = 3
    cam_mat4 = sdf_scan.get_camera_transform_looking_at_origin(phi, theta, camera_distance=camera_distance)
    # camera origin in world space (homogeneous product with [0, 0, 0, 1])
    cam_pos = cam_mat4 @ np.array([0, 0, 0, 1])
    scan = sdf_scan.Scan(mesh,
        camera_transform = cam_mat4,
        resolution = scan_resolution,
        calculate_normals = compute_normals,
        fov = fov,
        z_near = z_near,
        z_far = z_far,
        no_flip_backfaced_normals = True
    )
    # all the scan rays that hit the far plane, based on sdf_scan.Scan.__init__
    misses = np.argwhere(scan.depth_buffer == 0)
    points_miss = np.ones((misses.shape[0], 4))
    # pixel indices -> NDC in [-1, 1]; (row, col) maps to (y, x)
    points_miss[:, [1, 0]] = misses.astype(float) / (scan_resolution - 1) * 2 - 1
    points_miss[:, 1] *= -1 # flip y: image rows grow downwards
    points_miss[:, 2] = 1 # far plane in clipping space
    # unproject: clip space -> camera space -> world space
    points_miss = points_miss @ (cam_mat4 @ np.linalg.inv(scan.projection_matrix)).T
    points_miss /= points_miss[:, 3][:, np.newaxis] # perspective divide
    points_miss = points_miss[:, :3]
    uv_hits = scan.depth_buffer != 0
    uv_miss = ~uv_hits
    if not no_filter_backhits:
        if not compute_normals:
            raise ValueError("not `no_filter_backhits` requires `compute_normals`")
        # inner product
        mask = np.einsum('ij,ij->i', scan.points - cam_pos[:3][None, :], scan.normals) < 0
        scan.points = scan.points[mask, :]
        scan.normals = scan.normals[mask, :]
        # demote filtered backface hits: they become neither hit nor miss
        uv_hits[uv_hits] = mask
    transforms = {}
    # undo unit-sphere transform
    if no_unit_sphere:
        # inverse of x -> (x + translation) * scale
        scan.points = scan.points * (1 / scale) - translation
        points_miss = points_miss * (1 / scale) - translation
        cam_pos[:3] = cam_pos[:3] * (1 / scale) - translation
        # same inverse applied to the camera matrix:
        # translate(-translation) @ scale(1/scale) @ cam_mat4
        cam_mat4[:3, :] *= 1 / scale
        cam_mat4[:3, 3] -= translation
        transforms["unit_sphere"] = T.scale_and_translate(scale=scale, translate=translation)
        transforms["model"] = np.eye(4)
    else:
        transforms["model"] = np.linalg.inv(T.scale_and_translate(scale=scale, translate=translation))
        transforms["unit_sphere"] = np.eye(4)
    return SingleViewScan(
        normals_hit = scan.normals.astype(dtype),
        points_hit = scan.points.astype(dtype),
        points_miss = points_miss.astype(dtype),
        distances_miss = None,
        colors_hit = None,
        colors_miss = None,
        uv_hits = uv_hits.astype(bool),
        uv_miss = uv_miss.astype(bool),
        cam_pos = cam_pos[:3].astype(dtype),
        cam_mat4 = cam_mat4.astype(dtype),
        proj_mat4 = scan.projection_matrix.astype(dtype),
        transforms = {k:v.astype(dtype) for k, v in transforms.items()},
    )
def sample_sphere_view_scan_from_mesh(
    mesh : Trimesh,
    *,
    sphere_points : int = 4000, # resulting rays are n*(n-1)
    compute_normals : bool = False,
    no_filter_backhits : bool = False,
    no_unit_sphere : bool = False,
    no_permute : bool = False,
    dtype : type = np.float32,
    **kw,
) -> SingleViewUVScan:
    """
    Ray-cast `mesh` with rays spanning every ordered pair of equidistant
    sphere points (each ray origin on the unit sphere, target on the sphere).

    Rays that hit nothing keep their sphere exit point as a far-plane stand-in.
    Unless `no_permute`, all per-ray arrays are shuffled in lockstep; unless
    `no_unit_sphere`, results are returned in unit-sphere space.
    """
    translation, scale = compute_unit_sphere_transform(mesh, dtype=dtype)
    # get unit-sphere rays, then transform to model space (after the cache lookup)
    two_sphere = generate_equidistant_sphere_rays(sphere_points, **kw).astype(dtype) # (n*(n-1), 2, 3)
    two_sphere = two_sphere / scale - translation # we transform after cache lookup
    if mesh.ray.__class__.__module__.split(".")[-1] != "ray_pyembree":
        warnings.warn("Pyembree not found, the ray-tracing will be SLOW!")
    (
        locations,
        index_ray,
        index_tri,
    ) = mesh.ray.intersects_location(
        two_sphere[:, 0, :],
        two_sphere[:, 1, :] - two_sphere[:, 0, :], # direction, not target coordinate
        multiple_hits=False,
    )
    location_normals = None
    if compute_normals:
        location_normals = mesh.face_normals[index_tri]
    batch = two_sphere.shape[:1]
    # BUGFIX: `np.bool` was removed from numpy; use the builtin
    hits = np.zeros((*batch,), dtype=bool)
    miss = np.ones((*batch,), dtype=bool)
    cam_pos = two_sphere[:, 0, :]
    intersections = two_sphere[:, 1, :] # far-plane, effectively
    # BUGFIX: `normals` is now None when not computed (previously the
    # unconditional assignment below raised NameError in that case)
    normals = np.zeros((*batch, 3), dtype=dtype) if compute_normals else None
    index_ray_front = index_ray
    if not no_filter_backhits:
        if not compute_normals:
            raise ValueError("not `no_filter_backhits` requires `compute_normals`")
        # keep only front-facing hits: ray direction opposing the face normal
        mask = ((intersections[index_ray] - cam_pos[index_ray]) * location_normals).sum(axis=-1) <= 0
        index_ray_front = index_ray[mask]
    hits[index_ray_front] = True
    miss[index_ray] = False
    intersections[index_ray] = locations
    if compute_normals:
        normals[index_ray] = location_normals
    if not no_permute:
        assert len(batch) == 1, batch
        permutation = np.random.permutation(*batch)
        hits = hits[permutation]
        miss = miss[permutation]
        intersections = intersections[permutation, :]
        normals = normals[permutation, :] if normals is not None else None
        cam_pos = cam_pos[permutation, :]
    # apply unit sphere transform
    if not no_unit_sphere:
        intersections = (intersections + translation) * scale
        cam_pos = (cam_pos + translation) * scale
    return SingleViewUVScan(
        hits = hits,
        miss = miss,
        points = intersections,
        normals = normals,
        colors = None, # colors
        distances = None,
        cam_pos = cam_pos,
        cam_mat4 = None,
        proj_mat4 = None,
        transforms = {},
    )
def distance_from_rays_to_point_cloud(
    ray_origins : np.ndarray, # (*A, 3)
    ray_dirs : np.ndarray, # (*A, 3)
    points : np.ndarray, # (*B, 3)
    dirs_normalized : bool = False,
    n_steps : int = 40,
) -> np.ndarray: # (A)
    """
    Approximate, for each ray, the distance from the ray (as a line) to the
    point cloud.

    Sphere-traces every ray: repeatedly marches along the ray by a fraction of
    the distance to the current nearest neighbour while tracking the nearest
    cloud point seen, then returns the perpendicular distance from the ray to
    that point. Uses a BallTree for small clouds and faiss for large ones.

    ray_origins: (*A, 3) ray start points (broadcast against `ray_dirs`)
    ray_dirs: (*A, 3) ray directions (normalized here unless `dirs_normalized`)
    points: (*B, 3) the point cloud
    n_steps: number of marching iterations
    """
    # anything outside of this volume will never contribute to the result
    max_norm = max(
        np.linalg.norm(ray_origins, axis=-1).max(),
        np.linalg.norm(points, axis=-1).max(),
    ) * 1.02
    if not dirs_normalized:
        ray_dirs = ray_dirs / np.linalg.norm(ray_dirs, axis=-1)[..., None]
    # deal with single-view clouds
    if ray_origins.shape != ray_dirs.shape:
        ray_origins = np.broadcast_to(ray_origins, ray_dirs.shape)
    n_points = np.prod(points.shape[:-1]) # np.product was removed in NumPy 2.0
    use_faiss = n_points > 160000*4
    if not use_faiss:
        index = BallTree(points)
    else:
        # http://ann-benchmarks.com/index.html
        assert np.issubdtype(points.dtype, np.float32)
        assert np.issubdtype(ray_origins.dtype, np.float32)
        assert np.issubdtype(ray_dirs.dtype, np.float32)
        index = faiss.index_factory(points.shape[-1], "NSG32,Flat") # https://github.com/facebookresearch/faiss/wiki/The-index-factory
        index.nprobe = 5 # 10 # default is 1
        index.train(points)
        index.add(points)
    if not use_faiss:
        min_d, min_n = index.query(ray_origins, k=1, return_distance=True)
    else:
        min_d, min_n = index.search(ray_origins, k=1)
        min_d = np.sqrt(min_d) # faiss returns squared L2 distances
    acc_d = min_d.copy() # accumulated marching distance per ray
    for _step in range(1, n_steps+1):
        query_points = ray_origins + acc_d * ray_dirs
        # rays that left the bounding sphere can no longer improve their result
        # (the dead `max_norm is None` code path was removed: max_norm is
        # always a float here)
        qmask = np.linalg.norm(query_points, axis=-1) < max_norm
        if not qmask.any():
            break
        query_points = query_points[qmask]
        if not use_faiss:
            current_d, current_n = index.query(query_points, k=1, return_distance=True)
        else:
            current_d, current_n = index.search(query_points, k=1)
            current_d = np.sqrt(current_d)
        min_d[qmask] = np.minimum(current_d, min_d[qmask])
        # rows where the fresh neighbour is the best seen so far
        new_min_mask = min_d[qmask] == current_d
        qmask2 = qmask.copy()
        qmask2[qmask2] = new_min_mask[..., 0]
        min_n[qmask2] = current_n[new_min_mask[..., 0]]
        acc_d[qmask] += current_d * 0.25 # conservative quarter-steps
    closest_points = points[min_n[:, 0], :] # k=1
    # perpendicular distance from each ray line to its closest cloud point
    distances = np.linalg.norm(np.cross(closest_points - ray_origins, ray_dirs, axis=-1), axis=-1)
    return distances
# helpers
@compose(np.array) # return a copy so callers cannot mutate the lru-cached array
@lru_cache(maxsize=1)
def generate_equidistant_sphere_rays(n : int, **kw) -> np.ndarray: # output (n*n(-1)) rays, n may be off
    """
    Build rays between every ordered pair of distinct equidistant sphere points.

    Returns an array of shape (N*(N-1), 2, 3), where N is the number of sphere
    points actually generated (may differ slightly from `n`); each entry holds
    the (origin, target) pair of one ray.
    """
    sphere_points = points.generate_equidistant_sphere_points(n=n, **kw)
    idx = np.arange(len(sphere_points)) # (N)
    # all ordered index pairs: column 0 cycles fast, column 1 slow
    col_fast, col_slow = np.meshgrid(idx, idx)
    pairs = np.stack((col_fast.ravel(), col_slow.ravel()), axis=1) # (N**2, 2)
    # drop the diagonal (a point paired with itself)
    pairs = pairs[pairs[:, 0] != pairs[:, 1], :] # (N*(N-1), 2)
    # gather the endpoint coordinates for every pair
    return sphere_points[pairs, :] # (N*(N-1), 2, 3)
def compute_unit_sphere_transform(mesh: "Trimesh", *, dtype=None) -> tuple[np.ndarray, float]:
    """
    Return the (translation, scale) which mesh_to_sdf applies to meshes before
    computing their SDF cloud, i.e. x -> (x + translation) * scale.

    dtype: optional numpy dtype to cast the results to.
    BUGFIX: the default used to be the builtin `type`, which made every
    defaulted call crash inside `.astype(type)`; it is now None (no cast).
    """
    # the transformation applied by mesh_to_sdf.scale_to_unit_sphere(mesh)
    translation = -mesh.bounding_box.centroid
    scale = 1 / np.max(np.linalg.norm(mesh.vertices + translation, axis=1))
    if dtype is not None:
        translation = translation.astype(dtype)
        scale = scale.astype(dtype)
    return translation, scale

View File

@@ -0,0 +1,6 @@
__doc__ = """
Some helper types.
"""
class MalformedMesh(Exception):
    """Raised when a mesh is structurally invalid and cannot be processed."""