Add code
This commit is contained in:
0
ifield/data/common/__init__.py
Normal file
0
ifield/data/common/__init__.py
Normal file
90
ifield/data/common/download.py
Normal file
90
ifield/data/common/download.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from ...utils.helpers import make_relative
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from typing import Union, Optional
|
||||
import io
|
||||
import os
|
||||
import json
|
||||
import requests
|
||||
|
||||
PathLike = Union[os.PathLike, str]
|
||||
|
||||
__doc__ = """
|
||||
Here are some helper functions for processing data.
|
||||
"""
|
||||
|
||||
def check_url(url): # HTTP HEAD
|
||||
return requests.head(url).ok
|
||||
|
||||
def download_stream(
|
||||
url : str,
|
||||
file_object,
|
||||
block_size : int = 1024,
|
||||
silent : bool = False,
|
||||
label : Optional[str] = None,
|
||||
):
|
||||
resp = requests.get(url, stream=True)
|
||||
total_size = int(resp.headers.get("content-length", 0))
|
||||
if not silent:
|
||||
progress_bar = tqdm(total=total_size , unit="iB", unit_scale=True, desc=label)
|
||||
|
||||
for chunk in resp.iter_content(block_size):
|
||||
if not silent:
|
||||
progress_bar.update(len(chunk))
|
||||
file_object.write(chunk)
|
||||
|
||||
if not silent:
|
||||
progress_bar.close()
|
||||
if total_size != 0 and progress_bar.n != total_size:
|
||||
print("ERROR, something went wrong")
|
||||
|
||||
def download_data(
|
||||
url : str,
|
||||
block_size : int = 1024,
|
||||
silent : bool = False,
|
||||
label : Optional[str] = None,
|
||||
) -> bytearray:
|
||||
f = io.BytesIO()
|
||||
download_stream(url, f, block_size=block_size, silent=silent, label=label)
|
||||
f.seek(0)
|
||||
return bytearray(f.read())
|
||||
|
||||
def download_file(
|
||||
url : str,
|
||||
fname : Union[Path, str],
|
||||
block_size : int = 1024,
|
||||
silent = False,
|
||||
):
|
||||
if not isinstance(fname, Path):
|
||||
fname = Path(fname)
|
||||
with fname.open("wb") as f:
|
||||
download_stream(url, f, block_size=block_size, silent=silent, label=make_relative(fname, Path.cwd()).name)
|
||||
|
||||
def is_downloaded(
|
||||
target_dir : PathLike,
|
||||
url : str,
|
||||
*,
|
||||
add : bool = False,
|
||||
dbfiles : Union[list[PathLike], PathLike],
|
||||
):
|
||||
if not isinstance(target_dir, os.PathLike):
|
||||
target_dir = Path(target_dir)
|
||||
if not isinstance(dbfiles, list):
|
||||
dbfiles = [dbfiles]
|
||||
if not dbfiles:
|
||||
raise ValueError("'dbfiles' empty")
|
||||
downloaded = set()
|
||||
for dbfile_fname in dbfiles:
|
||||
dbfile_fname = target_dir / dbfile_fname
|
||||
if dbfile_fname.is_file():
|
||||
with open(dbfile_fname, "r") as f:
|
||||
downloaded.update(json.load(f)["downloaded"])
|
||||
|
||||
if add and url not in downloaded:
|
||||
downloaded.add(url)
|
||||
with open(dbfiles[0], "w") as f:
|
||||
data = {"downloaded": sorted(downloaded)}
|
||||
json.dump(data, f, indent=2, sort_keys=True)
|
||||
return True
|
||||
|
||||
return url in downloaded
|
||||
370
ifield/data/common/h5_dataclasses.py
Normal file
370
ifield/data/common/h5_dataclasses.py
Normal file
@@ -0,0 +1,370 @@
|
||||
#!/usr/bin/env python3
|
||||
from abc import abstractmethod, ABCMeta
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
import copy
|
||||
import dataclasses
|
||||
import functools
|
||||
import h5py as h5
|
||||
import hdf5plugin
|
||||
import numpy as np
|
||||
import operator
|
||||
import os
|
||||
import sys
|
||||
import typing
|
||||
|
||||
__all__ = [
|
||||
"DataclassMeta",
|
||||
"Dataclass",
|
||||
"H5Dataclass",
|
||||
"H5Array",
|
||||
"H5ArrayNoSlice",
|
||||
]
|
||||
|
||||
T = typing.TypeVar("T")
|
||||
NoneType = type(None)
|
||||
PathLike = typing.Union[os.PathLike, str]
|
||||
H5Array = typing._alias(np.ndarray, 0, inst=False, name="H5Array")
|
||||
H5ArrayNoSlice = typing._alias(np.ndarray, 0, inst=False, name="H5ArrayNoSlice")
|
||||
|
||||
DataclassField = namedtuple("DataclassField", [
|
||||
"name",
|
||||
"type",
|
||||
"is_optional",
|
||||
"is_array",
|
||||
"is_sliceable",
|
||||
"is_prefix",
|
||||
])
|
||||
|
||||
def strip_optional(val: type) -> type:
|
||||
if typing.get_origin(val) is typing.Union:
|
||||
union = set(typing.get_args(val))
|
||||
if len(union - {NoneType}) == 1:
|
||||
val, = union - {NoneType}
|
||||
else:
|
||||
raise TypeError(f"Non-'typing.Optional' 'typing.Union' is not supported: {typing._type_repr(val)!r}")
|
||||
return val
|
||||
|
||||
def is_array(val, *, _inner=False):
|
||||
"""
|
||||
Hacky way to check if a value or type is an array.
|
||||
The hack omits having to depend on large frameworks such as pytorch or pandas
|
||||
"""
|
||||
val = strip_optional(val)
|
||||
if val is H5Array or val is H5ArrayNoSlice:
|
||||
return True
|
||||
|
||||
if typing._type_repr(val) in (
|
||||
"numpy.ndarray",
|
||||
"torch.Tensor",
|
||||
):
|
||||
return True
|
||||
if not _inner:
|
||||
return is_array(type(val), _inner=True)
|
||||
return False
|
||||
|
||||
def prod(numbers: typing.Iterable[T], initial: typing.Optional[T] = None) -> T:
|
||||
if initial is not None:
|
||||
return functools.reduce(operator.mul, numbers, initial)
|
||||
else:
|
||||
return functools.reduce(operator.mul, numbers)
|
||||
|
||||
class DataclassMeta(type):
|
||||
def __new__(
|
||||
mcls,
|
||||
name : str,
|
||||
bases : tuple[type, ...],
|
||||
attrs : dict[str, typing.Any],
|
||||
**kwargs,
|
||||
):
|
||||
cls = super().__new__(mcls, name, bases, attrs, **kwargs)
|
||||
if sys.version_info[:2] >= (3, 10) and not hasattr(cls, "__slots__"):
|
||||
cls = dataclasses.dataclass(slots=True)(cls)
|
||||
else:
|
||||
cls = dataclasses.dataclass(cls)
|
||||
return cls
|
||||
|
||||
class DataclassABCMeta(DataclassMeta, ABCMeta):
|
||||
pass
|
||||
|
||||
class Dataclass(metaclass=DataclassMeta):
|
||||
def __getitem__(self, key: str) -> typing.Any:
|
||||
if key in self.keys():
|
||||
return getattr(self, key)
|
||||
raise KeyError(key)
|
||||
|
||||
def __setitem__(self, key: str, value: typing.Any):
|
||||
if key in self.keys():
|
||||
return setattr(self, key, value)
|
||||
raise KeyError(key)
|
||||
|
||||
def keys(self) -> typing.KeysView:
|
||||
return self.as_dict().keys()
|
||||
|
||||
def values(self) -> typing.ValuesView:
|
||||
return self.as_dict().values()
|
||||
|
||||
def items(self) -> typing.ItemsView:
|
||||
return self.as_dict().items()
|
||||
|
||||
def as_dict(self, properties_to_include: set[str] = None, **kw) -> dict[str, typing.Any]:
|
||||
out = dataclasses.asdict(self, **kw)
|
||||
for name in (properties_to_include or []):
|
||||
out[name] = getattr(self, name)
|
||||
return out
|
||||
|
||||
def as_tuple(self, properties_to_include: list[str]) -> tuple:
|
||||
out = dataclasses.astuple(self)
|
||||
if not properties_to_include:
|
||||
return out
|
||||
else:
|
||||
return (
|
||||
*out,
|
||||
*(getattr(self, name) for name in properties_to_include),
|
||||
)
|
||||
|
||||
def copy(self: T, *, deep=True) -> T:
|
||||
return (copy.deepcopy if deep else copy.copy)(self)
|
||||
|
||||
class H5Dataclass(Dataclass):
|
||||
# settable with class params:
|
||||
_prefix : str = dataclasses.field(init=False, repr=False, default="")
|
||||
_n_pages : int = dataclasses.field(init=False, repr=False, default=10)
|
||||
_require_all : bool = dataclasses.field(init=False, repr=False, default=False)
|
||||
|
||||
def __init_subclass__(cls,
|
||||
prefix : typing.Optional[str] = None,
|
||||
n_pages : typing.Optional[int] = None,
|
||||
require_all : typing.Optional[bool] = None,
|
||||
**kw,
|
||||
):
|
||||
super().__init_subclass__(**kw)
|
||||
assert dataclasses.is_dataclass(cls)
|
||||
if prefix is not None: cls._prefix = prefix
|
||||
if n_pages is not None: cls._n_pages = n_pages
|
||||
if require_all is not None: cls._require_all = require_all
|
||||
|
||||
@classmethod
|
||||
def _get_fields(cls) -> typing.Iterable[DataclassField]:
|
||||
for field in dataclasses.fields(cls):
|
||||
if not field.init:
|
||||
continue
|
||||
assert field.name not in ("_prefix", "_n_pages", "_require_all"), (
|
||||
f"{field.name!r} can not be in {cls.__qualname__}.__init__.\n"
|
||||
"Set it with dataclasses.field(default=YOUR_VALUE, init=False, repr=False)"
|
||||
)
|
||||
if isinstance(field.type, str):
|
||||
raise TypeError("Type hints are strings, perhaps avoid using `from __future__ import annotations`")
|
||||
|
||||
type_inner = strip_optional(field.type)
|
||||
is_prefix = typing.get_origin(type_inner) is dict and typing.get_args(type_inner)[:1] == (str,)
|
||||
field_type = typing.get_args(type_inner)[1] if is_prefix else field.type
|
||||
if field.default is None or typing.get_origin(field.type) is typing.Union and NoneType in typing.get_args(field.type):
|
||||
field_type = typing.Optional[field_type]
|
||||
|
||||
yield DataclassField(
|
||||
name = field.name,
|
||||
type = strip_optional(field_type),
|
||||
is_optional = typing.get_origin(field_type) is typing.Union and NoneType in typing.get_args(field_type),
|
||||
is_array = is_array(field_type),
|
||||
is_sliceable = is_array(field_type) and strip_optional(field_type) is not H5ArrayNoSlice,
|
||||
is_prefix = is_prefix,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_h5_file(cls : type[T],
|
||||
fname : typing.Union[PathLike, str],
|
||||
*,
|
||||
page : typing.Optional[int] = None,
|
||||
n_pages : typing.Optional[int] = None,
|
||||
read_slice : slice = slice(None),
|
||||
require_even_pages : bool = True,
|
||||
) -> T:
|
||||
if not isinstance(fname, Path):
|
||||
fname = Path(fname)
|
||||
if n_pages is None:
|
||||
n_pages = cls._n_pages
|
||||
if not fname.exists():
|
||||
raise FileNotFoundError(str(fname))
|
||||
if not h5.is_hdf5(fname):
|
||||
raise TypeError(f"Not a HDF5 file: {str(fname)!r}")
|
||||
|
||||
# if this class has no fields, print a example class:
|
||||
if not any(field.init for field in dataclasses.fields(cls)):
|
||||
with h5.File(fname, "r") as f:
|
||||
klen = max(map(len, f.keys()))
|
||||
example_cls = f"\nclass {cls.__name__}(Dataclass, require_all=True):\n" + "\n".join(
|
||||
f" {k.ljust(klen)} : "
|
||||
+ (
|
||||
"H5Array" if prod(v.shape, 1) > 1 else (
|
||||
"float" if issubclass(v.dtype.type, np.floating) else (
|
||||
"int" if issubclass(v.dtype.type, np.integer) else (
|
||||
"bool" if issubclass(v.dtype.type, np.bool_) else (
|
||||
"typing.Any"
|
||||
))))).ljust(14 + 1)
|
||||
+ f" #{repr(v).split(':', 1)[1].removesuffix('>')}"
|
||||
for k, v in f.items()
|
||||
)
|
||||
raise NotImplementedError(f"{cls!r} has no fields!\nPerhaps try the following:{example_cls}")
|
||||
|
||||
fields_consumed = set()
|
||||
|
||||
def make_kwarg(
|
||||
file : h5.File,
|
||||
keys : typing.KeysView,
|
||||
field : DataclassField,
|
||||
) -> tuple[str, typing.Any]:
|
||||
if field.is_optional:
|
||||
if field.name not in keys:
|
||||
return field.name, None
|
||||
if field.is_sliceable:
|
||||
if page is not None:
|
||||
n_items = int(f[cls._prefix + field.name].shape[0])
|
||||
page_len = n_items // n_pages
|
||||
modulus = n_items % n_pages
|
||||
if modulus: page_len += 1 # round up
|
||||
if require_even_pages and modulus:
|
||||
raise ValueError(f"Field {field.name!r} {tuple(f[cls._prefix + field.name].shape)} is not cleanly divisible into {n_pages} pages")
|
||||
this_slice = slice(
|
||||
start = page_len * page,
|
||||
stop = page_len * (page+1),
|
||||
step = read_slice.step, # inherit step
|
||||
)
|
||||
else:
|
||||
this_slice = read_slice
|
||||
else:
|
||||
this_slice = slice(None) # read all
|
||||
|
||||
# array or scalar?
|
||||
def read_dataset(var):
|
||||
# https://docs.h5py.org/en/stable/high/dataset.html#reading-writing-data
|
||||
if field.is_array:
|
||||
return var[this_slice]
|
||||
if var.shape == (1,):
|
||||
return var[0]
|
||||
else:
|
||||
return var[()]
|
||||
|
||||
if field.is_prefix:
|
||||
fields_consumed.update(
|
||||
key
|
||||
for key in keys if key.startswith(f"{cls._prefix}{field.name}_")
|
||||
)
|
||||
return field.name, {
|
||||
key.removeprefix(f"{cls._prefix}{field.name}_") : read_dataset(file[key])
|
||||
for key in keys if key.startswith(f"{cls._prefix}{field.name}_")
|
||||
}
|
||||
else:
|
||||
fields_consumed.add(cls._prefix + field.name)
|
||||
return field.name, read_dataset(file[cls._prefix + field.name])
|
||||
|
||||
with h5.File(fname, "r") as f:
|
||||
keys = f.keys()
|
||||
init_dict = dict( make_kwarg(f, keys, i) for i in cls._get_fields() )
|
||||
|
||||
try:
|
||||
out = cls(**init_dict)
|
||||
except Exception as e:
|
||||
class_attrs = set(field.name for field in dataclasses.fields(cls))
|
||||
file_attr = set(init_dict.keys())
|
||||
raise e.__class__(f"{e}. {class_attrs=}, {file_attr=}, diff={class_attrs.symmetric_difference(file_attr)}") from e
|
||||
|
||||
if cls._require_all:
|
||||
fields_not_consumed = set(keys) - fields_consumed
|
||||
if fields_not_consumed:
|
||||
raise ValueError(f"Not all HDF5 fields consumed: {fields_not_consumed!r}")
|
||||
|
||||
return out
|
||||
|
||||
def to_h5_file(self,
|
||||
fname : PathLike,
|
||||
mkdir : bool = False,
|
||||
):
|
||||
if not isinstance(fname, Path):
|
||||
fname = Path(fname)
|
||||
if not fname.parent.is_dir():
|
||||
if mkdir:
|
||||
fname.parent.mkdir(parents=True)
|
||||
else:
|
||||
raise NotADirectoryError(fname.parent)
|
||||
|
||||
with h5.File(fname, "w") as f:
|
||||
for field in type(self)._get_fields():
|
||||
if field.is_optional and getattr(self, field.name) is None:
|
||||
continue
|
||||
value = getattr(self, field.name)
|
||||
if field.is_array:
|
||||
if any(type(i) is not np.ndarray for i in (value.values() if field.is_prefix else [value])):
|
||||
raise TypeError(
|
||||
"When dumping a H5Dataclass, make sure the array fields are "
|
||||
f"numpy arrays (the type of {field.name!r} is {typing._type_repr(type(value))}).\n"
|
||||
"Example: h5dataclass.map_arrays(torch.Tensor.numpy)"
|
||||
)
|
||||
else:
|
||||
pass
|
||||
|
||||
def write_value(key: str, value: typing.Any):
|
||||
if field.is_array:
|
||||
f.create_dataset(key, data=value, **hdf5plugin.LZ4())
|
||||
else:
|
||||
f.create_dataset(key, data=value)
|
||||
|
||||
if field.is_prefix:
|
||||
for k, v in value.items():
|
||||
write_value(self._prefix + field.name + "_" + k, v)
|
||||
else:
|
||||
write_value(self._prefix + field.name, value)
|
||||
|
||||
def map_arrays(self: T, func: typing.Callable[[H5Array], H5Array], do_copy: bool = False) -> T:
|
||||
if do_copy: # shallow
|
||||
self = self.copy(deep=False)
|
||||
for field in type(self)._get_fields():
|
||||
if field.is_optional and getattr(self, field.name) is None:
|
||||
continue
|
||||
if field.is_prefix and field.is_array:
|
||||
setattr(self, field.name, {
|
||||
k : func(v)
|
||||
for k, v in getattr(self, field.name).items()
|
||||
})
|
||||
elif field.is_array:
|
||||
setattr(self, field.name, func(getattr(self, field.name)))
|
||||
|
||||
return self
|
||||
|
||||
def astype(self: T, t: type, do_copy: bool = False, convert_nonfloats: bool = False) -> T:
|
||||
return self.map_arrays(lambda x: x.astype(t) if convert_nonfloats or not np.issubdtype(x.dtype, int) else x)
|
||||
|
||||
def copy(self: T, *, deep=True) -> T:
|
||||
out = super().copy(deep=deep)
|
||||
if not deep:
|
||||
for field in type(self)._get_fields():
|
||||
if field.is_prefix:
|
||||
out[field.name] = copy.copy(field.name)
|
||||
return out
|
||||
|
||||
@property
|
||||
def shape(self) -> dict[str, tuple[int, ...]]:
|
||||
return {
|
||||
key: value.shape
|
||||
for key, value in self.items()
|
||||
if hasattr(value, "shape")
|
||||
}
|
||||
|
||||
class TransformableDataclassMixin(metaclass=DataclassABCMeta):
|
||||
|
||||
@abstractmethod
|
||||
def transform(self: T, mat4: np.ndarray, inplace=False) -> T:
|
||||
...
|
||||
|
||||
def transform_to(self: T, name: str, inverse_name: str = None, *, inplace=False) -> T:
|
||||
mtx = self.transforms[name]
|
||||
out = self.transform(mtx, inplace=inplace)
|
||||
out.transforms.pop(name) # consumed
|
||||
|
||||
inv = np.linalg.inv(mtx)
|
||||
for key in list(out.transforms.keys()): # maintain the other transforms
|
||||
out.transforms[key] = out.transforms[key] @ inv
|
||||
if inverse_name is not None: # store inverse
|
||||
out.transforms[inverse_name] = inv
|
||||
|
||||
return out
|
||||
48
ifield/data/common/mesh.py
Normal file
48
ifield/data/common/mesh.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from math import pi
|
||||
from trimesh import Trimesh
|
||||
import numpy as np
|
||||
import os
|
||||
import trimesh
|
||||
import trimesh.transformations as T
|
||||
|
||||
DEBUG = bool(os.environ.get("IFIELD_DEBUG", ""))
|
||||
|
||||
__doc__ = """
|
||||
Here are some helper functions for processing data.
|
||||
"""
|
||||
|
||||
def rotate_to_closest_axis_aligned_bounds(
|
||||
mesh : Trimesh,
|
||||
order_axes : bool = True,
|
||||
fail_ok : bool = True,
|
||||
) -> np.ndarray:
|
||||
to_origin_mat4, extents = trimesh.bounds.oriented_bounds(mesh, ordered=not order_axes)
|
||||
to_aabb_rot_mat4 = T.euler_matrix(*T.decompose_matrix(to_origin_mat4)[3])
|
||||
|
||||
if not order_axes:
|
||||
return to_aabb_rot_mat4
|
||||
|
||||
v = pi / 4 * 1.01 # tolerance
|
||||
v2 = pi / 2
|
||||
|
||||
faces = (
|
||||
(0, 0),
|
||||
(1, 0),
|
||||
(2, 0),
|
||||
(3, 0),
|
||||
(0, 1),
|
||||
(0,-1),
|
||||
)
|
||||
orientations = [ # 6 faces x 4 rotations per face
|
||||
(f[0] * v2, f[1] * v2, i * v2)
|
||||
for i in range(4)
|
||||
for f in faces]
|
||||
|
||||
for x, y, z in orientations:
|
||||
mat4 = T.euler_matrix(x, y, z) @ to_aabb_rot_mat4
|
||||
ai, aj, ak = T.euler_from_matrix(mat4)
|
||||
if abs(ai) <= v and abs(aj) <= v and abs(ak) <= v:
|
||||
return mat4
|
||||
|
||||
if fail_ok: return to_aabb_rot_mat4
|
||||
raise Exception("Unable to orient mesh")
|
||||
297
ifield/data/common/points.py
Normal file
297
ifield/data/common/points.py
Normal file
@@ -0,0 +1,297 @@
|
||||
from __future__ import annotations
|
||||
from ...utils.helpers import compose
|
||||
from functools import reduce, lru_cache
|
||||
from math import ceil
|
||||
from typing import Iterable
|
||||
import numpy as np
|
||||
import operator
|
||||
|
||||
__doc__ = """
|
||||
Here are some helper functions for processing data.
|
||||
"""
|
||||
|
||||
|
||||
def img2col(img: np.ndarray, psize: int) -> np.ndarray:
|
||||
# based of ycb_generate_point_cloud.py provided by YCB
|
||||
|
||||
n_channels = 1 if len(img.shape) == 2 else img.shape[0]
|
||||
n_channels, rows, cols = (1,) * (3 - len(img.shape)) + img.shape
|
||||
|
||||
# pad the image
|
||||
img_pad = np.zeros((
|
||||
n_channels,
|
||||
int(ceil(1.0 * rows / psize) * psize),
|
||||
int(ceil(1.0 * cols / psize) * psize),
|
||||
))
|
||||
img_pad[:, 0:rows, 0:cols] = img
|
||||
|
||||
# allocate output buffer
|
||||
final = np.zeros((
|
||||
img_pad.shape[1],
|
||||
img_pad.shape[2],
|
||||
n_channels,
|
||||
psize,
|
||||
psize,
|
||||
))
|
||||
|
||||
for c in range(n_channels):
|
||||
for x in range(psize):
|
||||
for y in range(psize):
|
||||
img_shift = np.vstack((
|
||||
img_pad[c, x:],
|
||||
img_pad[c, :x]))
|
||||
img_shift = np.column_stack((
|
||||
img_shift[:, y:],
|
||||
img_shift[:, :y]))
|
||||
final[x::psize, y::psize, c] = np.swapaxes(
|
||||
img_shift.reshape(
|
||||
int(img_pad.shape[1] / psize), psize,
|
||||
int(img_pad.shape[2] / psize), psize),
|
||||
1,
|
||||
2)
|
||||
|
||||
# crop output and unwrap axes with size==1
|
||||
return np.squeeze(final[
|
||||
0:rows - psize + 1,
|
||||
0:cols - psize + 1])
|
||||
|
||||
def filter_depth_discontinuities(depth_map: np.ndarray, filt_size = 7, thresh = 1000) -> np.ndarray:
|
||||
"""
|
||||
Removes data close to discontinuities, with size filt_size.
|
||||
"""
|
||||
# based of ycb_generate_point_cloud.py provided by YCB
|
||||
|
||||
# Ensure that filter sizes are okay
|
||||
assert filt_size % 2, "Can only use odd filter sizes."
|
||||
|
||||
# Compute discontinuities
|
||||
offset = int(filt_size - 1) // 2
|
||||
patches = 1.0 * img2col(depth_map, filt_size)
|
||||
mids = patches[:, :, offset, offset]
|
||||
mins = np.min(patches, axis=(2, 3))
|
||||
maxes = np.max(patches, axis=(2, 3))
|
||||
|
||||
discont = np.maximum(
|
||||
np.abs(mins - mids),
|
||||
np.abs(maxes - mids))
|
||||
mark = discont > thresh
|
||||
|
||||
# Account for offsets
|
||||
final_mark = np.zeros(depth_map.shape, dtype=np.uint16)
|
||||
final_mark[offset:offset + mark.shape[0],
|
||||
offset:offset + mark.shape[1]] = mark
|
||||
|
||||
return depth_map * (1 - final_mark)
|
||||
|
||||
def reorient_depth_map(
|
||||
depth_map : np.ndarray,
|
||||
rgb_map : np.ndarray,
|
||||
depth_mat3 : np.ndarray, # 3x3 intrinsic camera matrix
|
||||
depth_vec5 : np.ndarray, # 5 distortion parameters (k1, k2, p1, p2, k3)
|
||||
rgb_mat3 : np.ndarray, # 3x3 intrinsic camera matrix
|
||||
rgb_vec5 : np.ndarray, # 5 distortion parameters (k1, k2, p1, p2, k3)
|
||||
ir_to_rgb_mat4 : np.ndarray, # extrinsic transformation matrix from depth to rgb camera viewpoint
|
||||
rgb_mask_map : np.ndarray = None,
|
||||
_output_points = False, # retval (H, W) if false else (N, XYZRGB)
|
||||
_output_hits_uvs = False, # retval[1] is dtype=bool of hits shaped like depth_map
|
||||
) -> np.ndarray:
|
||||
|
||||
"""
|
||||
Corrects depth_map to be from the same view as the rgb_map, with the same dimensions.
|
||||
If _output_points is True, the points returned are in the rgb camera space.
|
||||
"""
|
||||
# based of ycb_generate_point_cloud.py provided by YCB
|
||||
# now faster AND more easy on the GIL
|
||||
|
||||
height_old, width_old, *_ = depth_map.shape
|
||||
height, width, *_ = rgb_map.shape
|
||||
|
||||
|
||||
d_cx, r_cx = depth_mat3[0, 2], rgb_mat3[0, 2] # optical center
|
||||
d_cy, r_cy = depth_mat3[1, 2], rgb_mat3[1, 2]
|
||||
d_fx, r_fx = depth_mat3[0, 0], rgb_mat3[0, 0] # focal length
|
||||
d_fy, r_fy = depth_mat3[1, 1], rgb_mat3[1, 1]
|
||||
d_k1, d_k2, d_p1, d_p2, d_k3 = depth_vec5
|
||||
c_k1, c_k2, c_p1, c_p2, c_k3 = rgb_vec5
|
||||
|
||||
# make a UV grid over depth_map
|
||||
u, v = np.meshgrid(
|
||||
np.arange(width_old),
|
||||
np.arange(height_old),
|
||||
)
|
||||
|
||||
# compute xyz coordinates for all depths
|
||||
xyz_depth = np.stack((
|
||||
(u - d_cx) / d_fx,
|
||||
(v - d_cy) / d_fy,
|
||||
depth_map,
|
||||
np.ones(depth_map.shape)
|
||||
)).reshape((4, -1))
|
||||
xyz_depth = xyz_depth[:, xyz_depth[2] != 0]
|
||||
|
||||
# undistort depth coordinates
|
||||
d_x, d_y = xyz_depth[:2]
|
||||
r = np.linalg.norm(xyz_depth[:2], axis=0)
|
||||
xyz_depth[0, :] \
|
||||
= d_x / (1 + d_k1*r**2 + d_k2*r**4 + d_k3*r**6) \
|
||||
- (2*d_p1*d_x*d_y + d_p2*(r**2 + 2*d_x**2))
|
||||
xyz_depth[1, :] \
|
||||
= d_y / (1 + d_k1*r**2 + d_k2*r**4 + d_k3*r**6) \
|
||||
- (d_p1*(r**2 + 2*d_y**2) + 2*d_p2*d_x*d_y)
|
||||
|
||||
# unproject x and y
|
||||
xyz_depth[0, :] *= xyz_depth[2, :]
|
||||
xyz_depth[1, :] *= xyz_depth[2, :]
|
||||
|
||||
# convert depths to RGB camera viewpoint
|
||||
xyz_rgb = ir_to_rgb_mat4 @ xyz_depth
|
||||
|
||||
# project depths to RGB canvas
|
||||
rgb_z_inv = 1 / xyz_rgb[2] # perspective correction
|
||||
rgb_uv = np.stack((
|
||||
xyz_rgb[0] * rgb_z_inv * r_fx + r_cx + 0.5,
|
||||
xyz_rgb[1] * rgb_z_inv * r_fy + r_cy + 0.5,
|
||||
)).astype(np.int)
|
||||
|
||||
# mask of the rgb_xyz values within view of rgb_map
|
||||
mask = reduce(operator.and_, [
|
||||
rgb_uv[0] >= 0,
|
||||
rgb_uv[1] >= 0,
|
||||
rgb_uv[0] < width,
|
||||
rgb_uv[1] < height,
|
||||
])
|
||||
if rgb_mask_map is not None:
|
||||
mask[mask] &= rgb_mask_map[
|
||||
rgb_uv[1, mask],
|
||||
rgb_uv[0, mask]]
|
||||
|
||||
if not _output_points: # output image
|
||||
output = np.zeros((height, width), dtype=depth_map.dtype)
|
||||
output[
|
||||
rgb_uv[1, mask],
|
||||
rgb_uv[0, mask],
|
||||
] = xyz_rgb[2, mask]
|
||||
|
||||
else: # output pointcloud
|
||||
rgbs = rgb_map[ # lookup rgb values using rgb_uv
|
||||
rgb_uv[1, mask],
|
||||
rgb_uv[0, mask]]
|
||||
output = np.stack((
|
||||
xyz_rgb[0, mask], # x
|
||||
xyz_rgb[1, mask], # y
|
||||
xyz_rgb[2, mask], # z
|
||||
rgbs[:, 0], # r
|
||||
rgbs[:, 1], # g
|
||||
rgbs[:, 2], # b
|
||||
)).T
|
||||
|
||||
# output for realsies
|
||||
if not _output_hits_uvs: #raw
|
||||
return output
|
||||
else: # with hit mask
|
||||
uv = np.zeros((height, width), dtype=bool)
|
||||
# filter points overlapping in the depth map
|
||||
uv_indices = (
|
||||
rgb_uv[1, mask],
|
||||
rgb_uv[0, mask],
|
||||
)
|
||||
_, chosen = np.unique( uv_indices[0] << 32 | uv_indices[1], return_index=True )
|
||||
output = output[chosen, :]
|
||||
uv[uv_indices] = True
|
||||
return output, uv
|
||||
|
||||
def join_rgb_and_depth_to_points(*a, **kw) -> np.ndarray:
|
||||
return reorient_depth_map(*a, _output_points=True, **kw)
|
||||
|
||||
@compose(np.array) # block lru cache mutation
|
||||
@lru_cache(maxsize=1)
|
||||
@compose(list)
|
||||
def generate_equidistant_sphere_points(
|
||||
n : int,
|
||||
centroid : np.ndarray = (0, 0, 0),
|
||||
radius : float = 1,
|
||||
compute_sphere_coordinates : bool = False,
|
||||
compute_normals : bool = False,
|
||||
shift_theta : bool = False,
|
||||
) -> Iterable[tuple[float, ...]]:
|
||||
# Deserno M. How to generate equidistributed points on the surface of a sphere
|
||||
# https://www.cmu.edu/biolphys/deserno/pdf/sphere_equi.pdf
|
||||
|
||||
if compute_sphere_coordinates and compute_normals:
|
||||
raise ValueError(
|
||||
"'compute_sphere_coordinates' and 'compute_normals' are mutually exclusive"
|
||||
)
|
||||
|
||||
n_count = 0
|
||||
a = 4 * np.pi / n
|
||||
d = np.sqrt(a)
|
||||
n_theta = round(np.pi / d)
|
||||
d_theta = np.pi / n_theta
|
||||
d_phi = a / d_theta
|
||||
|
||||
for i in range(0, n_theta):
|
||||
theta = np.pi * (i + 0.5) / n_theta
|
||||
n_phi = round(2 * np.pi * np.sin(theta) / d_phi)
|
||||
|
||||
for j in range(0, n_phi):
|
||||
phi = 2 * np.pi * j / n_phi
|
||||
|
||||
if compute_sphere_coordinates: # (theta, phi)
|
||||
yield (
|
||||
theta if shift_theta else theta - 0.5*np.pi,
|
||||
phi,
|
||||
)
|
||||
elif compute_normals: # (x, y, z, nx, ny, nz)
|
||||
yield (
|
||||
centroid[0] + radius * np.sin(theta) * np.cos(phi),
|
||||
centroid[1] + radius * np.sin(theta) * np.sin(phi),
|
||||
centroid[2] + radius * np.cos(theta),
|
||||
np.sin(theta) * np.cos(phi),
|
||||
np.sin(theta) * np.sin(phi),
|
||||
np.cos(theta),
|
||||
)
|
||||
else: # (x, y, z)
|
||||
yield (
|
||||
centroid[0] + radius * np.sin(theta) * np.cos(phi),
|
||||
centroid[1] + radius * np.sin(theta) * np.sin(phi),
|
||||
centroid[2] + radius * np.cos(theta),
|
||||
)
|
||||
n_count += 1
|
||||
|
||||
|
||||
def generate_random_sphere_points(
|
||||
n : int,
|
||||
centroid : np.ndarray = (0, 0, 0),
|
||||
radius : float = 1,
|
||||
compute_sphere_coordinates : bool = False,
|
||||
compute_normals : bool = False,
|
||||
shift_theta : bool = False, # depends on convention
|
||||
) -> np.ndarray:
|
||||
if compute_sphere_coordinates and compute_normals:
|
||||
raise ValueError(
|
||||
"'compute_sphere_coordinates' and 'compute_normals' are mutually exclusive"
|
||||
)
|
||||
|
||||
theta = np.arcsin(np.random.uniform(-1, 1, n)) # inverse transform sampling
|
||||
phi = np.random.uniform(0, 2*np.pi, n)
|
||||
|
||||
if compute_sphere_coordinates: # (theta, phi)
|
||||
return np.stack((
|
||||
theta if not shift_theta else 0.5*np.pi + theta,
|
||||
phi,
|
||||
), axis=1)
|
||||
elif compute_normals: # (x, y, z, nx, ny, nz)
|
||||
return np.stack((
|
||||
centroid[0] + radius * np.cos(theta) * np.cos(phi),
|
||||
centroid[1] + radius * np.cos(theta) * np.sin(phi),
|
||||
centroid[2] + radius * np.sin(theta),
|
||||
np.cos(theta) * np.cos(phi),
|
||||
np.cos(theta) * np.sin(phi),
|
||||
np.sin(theta),
|
||||
), axis=1)
|
||||
else: # (x, y, z)
|
||||
return np.stack((
|
||||
centroid[0] + radius * np.cos(theta) * np.cos(phi),
|
||||
centroid[1] + radius * np.cos(theta) * np.sin(phi),
|
||||
centroid[2] + radius * np.sin(theta),
|
||||
), axis=1)
|
||||
85
ifield/data/common/processing.py
Normal file
85
ifield/data/common/processing.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from .h5_dataclasses import H5Dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Hashable, Optional, Callable
|
||||
import os
|
||||
|
||||
DEBUG = bool(os.environ.get("IFIELD_DEBUG", ""))
|
||||
|
||||
__doc__ = """
|
||||
Here are some helper functions for processing data.
|
||||
"""
|
||||
|
||||
# multiprocessing does not work due to my rediculous use of closures, which seemingly cannot be pickled
|
||||
# paralelize it in the shell instead
|
||||
|
||||
def precompute_data(
|
||||
computer : Callable[[Hashable], Optional[H5Dataclass]],
|
||||
identifiers : list[Hashable],
|
||||
output_paths : list[Path],
|
||||
page : tuple[int, int] = (0, 1),
|
||||
*,
|
||||
force : bool = False,
|
||||
debug : bool = False,
|
||||
):
|
||||
"""
|
||||
precomputes data and stores them as HDF5 datasets using `.to_file(path: Path)`
|
||||
"""
|
||||
|
||||
page, n_pages = page
|
||||
assert len(identifiers) == len(output_paths)
|
||||
|
||||
total = len(identifiers)
|
||||
identifier_max_len = max(map(len, map(str, identifiers)))
|
||||
t_epoch = None
|
||||
def log(state: str, is_start = False):
|
||||
nonlocal t_epoch
|
||||
if is_start: t_epoch = datetime.now()
|
||||
td = timedelta(0) if is_start else datetime.now() - t_epoch
|
||||
print(" - "
|
||||
f"{str(index+1).rjust(len(str(total)))}/{total}: "
|
||||
f"{str(identifier).ljust(identifier_max_len)} @ {td}: {state}"
|
||||
)
|
||||
|
||||
print(f"precompute_data(computer={computer.__module__}.{computer.__qualname__}, identifiers=..., force={force}, page={page})")
|
||||
t_begin = datetime.now()
|
||||
failed = []
|
||||
|
||||
# pagination
|
||||
page_size = total // n_pages + bool(total % n_pages)
|
||||
jobs = list(zip(identifiers, output_paths))[page_size*page : page_size*(page+1)]
|
||||
|
||||
for index, (identifier, output_path) in enumerate(jobs, start=page_size*page):
|
||||
if not force and output_path.exists() and output_path.stat().st_size > 0:
|
||||
continue
|
||||
|
||||
log("compute", is_start=True)
|
||||
|
||||
# compute
|
||||
try:
|
||||
res = computer(identifier)
|
||||
except Exception as e:
|
||||
failed.append(identifier)
|
||||
log(f"failed compute: {e.__class__.__name__}: {e}")
|
||||
if DEBUG or debug: raise e
|
||||
continue
|
||||
if res is None:
|
||||
failed.append(identifier)
|
||||
log("no result")
|
||||
continue
|
||||
|
||||
# write to file
|
||||
try:
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
res.to_h5_file(output_path)
|
||||
except Exception as e:
|
||||
failed.append(identifier)
|
||||
log(f"failed write: {e.__class__.__name__}: {e}")
|
||||
if output_path.is_file(): output_path.unlink() # cleanup
|
||||
if DEBUG or debug: raise e
|
||||
continue
|
||||
|
||||
log("done")
|
||||
|
||||
print("precompute_data finished in", datetime.now() - t_begin)
|
||||
print("failed:", failed or None)
|
||||
768
ifield/data/common/scan.py
Normal file
768
ifield/data/common/scan.py
Normal file
@@ -0,0 +1,768 @@
|
||||
from ...utils.helpers import compose
|
||||
from . import points
|
||||
from .h5_dataclasses import H5Dataclass, H5Array, H5ArrayNoSlice, TransformableDataclassMixin
|
||||
from methodtools import lru_cache
|
||||
from sklearn.neighbors import BallTree
|
||||
import faiss
|
||||
from trimesh import Trimesh
|
||||
from typing import Iterable
|
||||
from typing import Optional, TypeVar
|
||||
import mesh_to_sdf
|
||||
import mesh_to_sdf.scan as sdf_scan
|
||||
import numpy as np
|
||||
import trimesh
|
||||
import trimesh.transformations as T
|
||||
import warnings
|
||||
|
||||
__doc__ = """
|
||||
Here are some helper types for data.
|
||||
"""
|
||||
|
||||
_T = TypeVar("T")
|
||||
|
||||
class InvalidateLRUOnWriteMixin:
|
||||
def __setattr__(self, key, value):
|
||||
if not key.startswith("__wire|"):
|
||||
for attr in dir(self):
|
||||
if attr.startswith("__wire|"):
|
||||
getattr(self, attr).cache_clear()
|
||||
return super().__setattr__(key, value)
|
||||
def lru_property(func):
|
||||
return lru_cache(maxsize=1)(property(func))
|
||||
|
||||
class SingleViewScan(H5Dataclass, TransformableDataclassMixin, InvalidateLRUOnWriteMixin, require_all=True):
|
||||
points_hit : H5ArrayNoSlice # (N, 3)
|
||||
normals_hit : Optional[H5ArrayNoSlice] # (N, 3)
|
||||
points_miss : H5ArrayNoSlice # (M, 3)
|
||||
distances_miss : Optional[H5ArrayNoSlice] # (M)
|
||||
colors_hit : Optional[H5ArrayNoSlice] # (N, 3)
|
||||
colors_miss : Optional[H5ArrayNoSlice] # (M, 3)
|
||||
uv_hits : Optional[H5ArrayNoSlice] # (H, W) dtype=bool
|
||||
uv_miss : Optional[H5ArrayNoSlice] # (H, W) dtype=bool (the reason we store both is due to missing data depth sensor data or filtered backfaces)
|
||||
cam_pos : H5ArrayNoSlice # (3)
|
||||
cam_mat4 : Optional[H5ArrayNoSlice] # (4, 4)
|
||||
proj_mat4 : Optional[H5ArrayNoSlice] # (4, 4)
|
||||
transforms : dict[str, H5ArrayNoSlice] # a map of 4x4 transformation matrices
|
||||
|
||||
def transform(self: _T, mat4: np.ndarray, inplace=False) -> _T:
|
||||
scale_xyz = mat4[:3, :3].sum(axis=0) # https://math.stackexchange.com/a/1463487
|
||||
assert all(scale_xyz - scale_xyz[0] < 1e-8), f"differenty scaled axes: {scale_xyz}"
|
||||
|
||||
out = self if inplace else self.copy(deep=False)
|
||||
out.points_hit = T.transform_points(self.points_hit, mat4)
|
||||
out.normals_hit = T.transform_points(self.normals_hit, mat4) if self.normals_hit is not None else None
|
||||
out.points_miss = T.transform_points(self.points_miss, mat4)
|
||||
out.distances_miss = self.distances_miss * scale_xyz
|
||||
out.cam_pos = T.transform_points(self.points_cam, mat4)[-1]
|
||||
out.cam_mat4 = (mat4 @ self.cam_mat4) if self.cam_mat4 is not None else None
|
||||
out.proj_mat4 = (mat4 @ self.proj_mat4) if self.proj_mat4 is not None else None
|
||||
return out
|
||||
|
||||
def compute_miss_distances(self: _T, *, copy: bool = False, deep: bool = False) -> _T:
|
||||
assert not self.has_miss_distances
|
||||
if not self.is_hitting:
|
||||
raise ValueError("No hits to compute the ray distance towards")
|
||||
|
||||
out = self.copy(deep=deep) if copy else self
|
||||
out.distances_miss \
|
||||
= distance_from_rays_to_point_cloud(
|
||||
ray_origins = out.points_cam,
|
||||
ray_dirs = out.ray_dirs_miss,
|
||||
points = out.points_hit,
|
||||
).astype(out.points_cam.dtype)
|
||||
|
||||
return out
|
||||
|
||||
@lru_property
|
||||
def points(self) -> np.ndarray: # (N+M+1, 3)
|
||||
return np.concatenate((
|
||||
self.points_hit,
|
||||
self.points_miss,
|
||||
self.points_cam,
|
||||
))
|
||||
|
||||
@lru_property
|
||||
def uv_points(self) -> np.ndarray: # (N+M+1, 3)
|
||||
if not self.has_uv: raise ValueError
|
||||
out = np.full((*self.uv_hits.shape, 3), np.nan, dtype=self.points_hit.dtype)
|
||||
out[self.uv_hits, :] = self.points_hit
|
||||
out[self.uv_miss, :] = self.points_miss
|
||||
return out
|
||||
|
||||
@lru_property
|
||||
def uv_normals(self) -> np.ndarray: # (N+M+1, 3)
|
||||
if not self.has_uv: raise ValueError
|
||||
out = np.full((*self.uv_hits.shape, 3), np.nan, dtype=self.normals_hit.dtype)
|
||||
out[self.uv_hits, :] = self.normals_hit
|
||||
return out
|
||||
|
||||
@lru_property
|
||||
def points_cam(self) -> Optional[np.ndarray]: # (1, 3)
|
||||
if self.cam_pos is None: return None
|
||||
return self.cam_pos[None, :]
|
||||
|
||||
@lru_property
|
||||
def points_hit_centroid(self) -> np.ndarray:
|
||||
return self.points_hit.mean(axis=0)
|
||||
|
||||
@lru_property
|
||||
def points_hit_std(self) -> np.ndarray:
|
||||
return self.points_hit.std(axis=0)
|
||||
|
||||
@lru_property
|
||||
def is_hitting(self) -> bool:
|
||||
return len(self.points_hit) > 0
|
||||
|
||||
@lru_property
|
||||
def is_empty(self) -> bool:
|
||||
return not (len(self.points_hit) or len(self.points_miss))
|
||||
|
||||
@lru_property
|
||||
def has_colors(self) -> bool:
|
||||
return self.colors_hit is not None or self.colors_miss is not None
|
||||
|
||||
@lru_property
|
||||
def has_normals(self) -> bool:
|
||||
return self.normals_hit is not None
|
||||
|
||||
@lru_property
|
||||
def has_uv(self) -> bool:
|
||||
return self.uv_hits is not None
|
||||
|
||||
@lru_property
|
||||
def has_miss_distances(self) -> bool:
|
||||
return self.distances_miss is not None
|
||||
|
||||
@lru_property
|
||||
def xyzrgb_hit(self) -> np.ndarray: # (N, 6)
|
||||
if self.colors_hit is None: raise ValueError
|
||||
return np.concatenate([self.points_hit, self.colors_hit], axis=1)
|
||||
|
||||
@lru_property
|
||||
def xyzrgb_miss(self) -> np.ndarray: # (M, 6)
|
||||
if self.colors_miss is None: raise ValueError
|
||||
return np.concatenate([self.points_miss, self.colors_miss], axis=1)
|
||||
|
||||
@lru_property
|
||||
def ray_dirs_hit(self) -> np.ndarray: # (N, 3)
|
||||
out = self.points_hit - self.points_cam
|
||||
out /= np.linalg.norm(out, axis=-1)[:, None] # normalize
|
||||
return out
|
||||
|
||||
@lru_property
|
||||
def ray_dirs_miss(self) -> np.ndarray: # (N, 3)
|
||||
out = self.points_miss - self.points_cam
|
||||
out /= np.linalg.norm(out, axis=-1)[:, None] # normalize
|
||||
return out
|
||||
|
||||
@classmethod
|
||||
def from_mesh_single_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewScan":
|
||||
if "phi" not in kw and not "theta" in kw:
|
||||
kw["theta"], kw["phi"] = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
|
||||
scan = sample_single_view_scan_from_mesh(mesh, **kw)
|
||||
if compute_miss_distances and scan.is_hitting:
|
||||
scan.compute_miss_distances()
|
||||
return scan
|
||||
|
||||
def to_uv_scan(self) -> "SingleViewUVScan":
|
||||
return SingleViewUVScan.from_scan(self)
|
||||
|
||||
@classmethod
|
||||
def from_uv_scan(self, uvscan: "SingleViewUVScan") -> "SingleViewUVScan":
|
||||
return uvscan.to_scan()
|
||||
|
||||
# The same, but with support for pagination (should have been this way since the start...)
|
||||
class SingleViewUVScan(H5Dataclass, TransformableDataclassMixin, InvalidateLRUOnWriteMixin, require_all=True):
|
||||
# B may be (N) or (H, W), the latter may be flattened
|
||||
hits : H5Array # (*B) dtype=bool
|
||||
miss : H5Array # (*B) dtype=bool (the reason we store both is due to missing data depth sensor data or filtered backface hits)
|
||||
points : H5Array # (*B, 3) on far plane if miss, NaN if neither hit or miss
|
||||
normals : Optional[H5Array] # (*B, 3) NaN if not hit
|
||||
colors : Optional[H5Array] # (*B, 3)
|
||||
distances : Optional[H5Array] # (*B) NaN if not miss
|
||||
cam_pos : Optional[H5ArrayNoSlice] # (3) or (*B, 3)
|
||||
cam_mat4 : Optional[H5ArrayNoSlice] # (4, 4)
|
||||
proj_mat4 : Optional[H5ArrayNoSlice] # (4, 4)
|
||||
transforms : dict[str, H5ArrayNoSlice] # a map of 4x4 transformation matrices
|
||||
|
||||
@classmethod
|
||||
def from_scan(cls, scan: SingleViewScan):
|
||||
if not scan.has_uv:
|
||||
raise ValueError("Scan cloud has no UV data")
|
||||
hits, miss = scan.uv_hits, scan.uv_miss
|
||||
dtype = scan.points_hit.dtype
|
||||
assert hits.ndim in (1, 2), hits.ndim
|
||||
assert hits.shape == miss.shape, (hits.shape, miss.shape)
|
||||
|
||||
points = np.full((*hits.shape, 3), np.nan, dtype=dtype)
|
||||
points[hits, :] = scan.points_hit
|
||||
points[miss, :] = scan.points_miss
|
||||
|
||||
normals = None
|
||||
if scan.has_normals:
|
||||
normals = np.full((*hits.shape, 3), np.nan, dtype=dtype)
|
||||
normals[hits, :] = scan.normals_hit
|
||||
|
||||
distances = None
|
||||
if scan.has_miss_distances:
|
||||
distances = np.full(hits.shape, np.nan, dtype=dtype)
|
||||
distances[miss] = scan.distances_miss
|
||||
|
||||
colors = None
|
||||
if scan.has_colors:
|
||||
colors = np.full((*hits.shape, 3), np.nan, dtype=dtype)
|
||||
if scan.colors_hit is not None:
|
||||
colors[hits, :] = scan.colors_hit
|
||||
if scan.colors_miss is not None:
|
||||
colors[miss, :] = scan.colors_miss
|
||||
|
||||
return cls(
|
||||
hits = hits,
|
||||
miss = miss,
|
||||
points = points,
|
||||
normals = normals,
|
||||
colors = colors,
|
||||
distances = distances,
|
||||
cam_pos = scan.cam_pos,
|
||||
cam_mat4 = scan.cam_mat4,
|
||||
proj_mat4 = scan.proj_mat4,
|
||||
transforms = scan.transforms,
|
||||
)
|
||||
|
||||
def to_scan(self) -> "SingleViewScan":
|
||||
if not self.is_single_view: raise ValueError
|
||||
return SingleViewScan(
|
||||
points_hit = self.points [self.hits, :],
|
||||
points_miss = self.points [self.miss, :],
|
||||
normals_hit = self.normals [self.hits, :] if self.has_normals else None,
|
||||
distances_miss = self.distances[self.miss] if self.has_miss_distances else None,
|
||||
colors_hit = self.colors [self.hits, :] if self.has_colors else None,
|
||||
colors_miss = self.colors [self.miss, :] if self.has_colors else None,
|
||||
uv_hits = self.hits,
|
||||
uv_miss = self.miss,
|
||||
cam_pos = self.cam_pos,
|
||||
cam_mat4 = self.cam_mat4,
|
||||
proj_mat4 = self.proj_mat4,
|
||||
transforms = self.transforms,
|
||||
)
|
||||
|
||||
def to_mesh(self) -> trimesh.Trimesh:
|
||||
faces: list[(tuple[int, int],)*3] = []
|
||||
for x in range(self.hits.shape[0]-1):
|
||||
for y in range(self.hits.shape[1]-1):
|
||||
c11 = x, y
|
||||
c12 = x, y+1
|
||||
c22 = x+1, y+1
|
||||
c21 = x+1, y
|
||||
|
||||
n = sum(map(self.hits.__getitem__, (c11, c12, c22, c21)))
|
||||
if n == 3:
|
||||
faces.append((*filter(self.hits.__getitem__, (c11, c12, c22, c21)),))
|
||||
elif n == 4:
|
||||
faces.append((c11, c12, c22))
|
||||
faces.append((c11, c22, c21))
|
||||
xy2idx = {c:i for i, c in enumerate(set(k for j in faces for k in j))}
|
||||
assert self.colors is not None
|
||||
return trimesh.Trimesh(
|
||||
vertices = [self.points[i] for i in xy2idx.keys()],
|
||||
vertex_colors = [self.colors[i] for i in xy2idx.keys()] if self.colors is not None else None,
|
||||
faces = [tuple(xy2idx[i] for i in face) for face in faces],
|
||||
)
|
||||
|
||||
def transform(self: _T, mat4: np.ndarray, inplace=False) -> _T:
|
||||
scale_xyz = mat4[:3, :3].sum(axis=0) # https://math.stackexchange.com/a/1463487
|
||||
assert all(scale_xyz - scale_xyz[0] < 1e-8), f"differenty scaled axes: {scale_xyz}"
|
||||
|
||||
unflat = self.hits.shape
|
||||
flat = np.product(unflat)
|
||||
|
||||
out = self if inplace else self.copy(deep=False)
|
||||
out.points = T.transform_points(self.points .reshape((*flat, 3)), mat4).reshape((*unflat, 3))
|
||||
out.normals = T.transform_points(self.normals.reshape((*flat, 3)), mat4).reshape((*unflat, 3)) if self.normals_hit is not None else None
|
||||
out.distances = self.distances_miss * scale_xyz
|
||||
out.cam_pos = T.transform_points(self.cam_pos[None, ...], mat4)[0]
|
||||
out.cam_mat4 = (mat4 @ self.cam_mat4) if self.cam_mat4 is not None else None
|
||||
out.proj_mat4 = (mat4 @ self.proj_mat4) if self.proj_mat4 is not None else None
|
||||
return out
|
||||
|
||||
def compute_miss_distances(self: _T, *, copy: bool = False, deep: bool = False, surface_points: Optional[np.ndarray] = None) -> _T:
|
||||
assert not self.has_miss_distances
|
||||
|
||||
shape = self.hits.shape
|
||||
|
||||
out = self.copy(deep=deep) if copy else self
|
||||
out.distances = np.zeros(shape, dtype=self.points.dtype)
|
||||
if self.is_hitting:
|
||||
out.distances[self.miss] \
|
||||
= distance_from_rays_to_point_cloud(
|
||||
ray_origins = self.cam_pos_unsqueezed_miss,
|
||||
ray_dirs = self.ray_dirs_miss,
|
||||
points = surface_points if surface_points is not None else self.points[self.hits],
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
def fill_missing_points(self: _T, *, copy: bool = False, deep: bool = False) -> _T:
|
||||
"""
|
||||
Fill in missing points as hitting the far plane.
|
||||
"""
|
||||
if not self.is_2d:
|
||||
raise ValueError("Cannot fill missing points for non-2d scan!")
|
||||
if not self.is_single_view:
|
||||
raise ValueError("Cannot fill missing points for non-single-view scans!")
|
||||
if self.cam_mat4 is None:
|
||||
raise ValueError("cam_mat4 is None")
|
||||
if self.proj_mat4 is None:
|
||||
raise ValueError("proj_mat4 is None")
|
||||
|
||||
uv = np.argwhere(self.missing).astype(self.points.dtype)
|
||||
uv[:, 0] /= (self.missing.shape[1] - 1) / 2
|
||||
uv[:, 1] /= (self.missing.shape[0] - 1) / 2
|
||||
uv -= 1
|
||||
uv = np.stack((
|
||||
uv[:, 1],
|
||||
-uv[:, 0],
|
||||
np.ones(uv.shape[0]), # far clipping plane
|
||||
np.ones(uv.shape[0]), # homogeneous coordinate
|
||||
), axis=-1)
|
||||
uv = uv @ (self.cam_mat4 @ np.linalg.inv(self.proj_mat4)).T
|
||||
|
||||
out = self.copy(deep=deep) if copy else self
|
||||
out.points[self.missing, :] = uv[:, :3] / uv[:, 3][:, None]
|
||||
return out
|
||||
|
||||
@lru_property
|
||||
def is_hitting(self) -> bool:
|
||||
return np.any(self.hits)
|
||||
|
||||
@lru_property
|
||||
def has_colors(self) -> bool:
|
||||
return not self.colors is None
|
||||
|
||||
@lru_property
|
||||
def has_normals(self) -> bool:
|
||||
return not self.normals is None
|
||||
|
||||
@lru_property
|
||||
def has_miss_distances(self) -> bool:
|
||||
return not self.distances is None
|
||||
|
||||
@lru_property
|
||||
def any_missing(self) -> bool:
|
||||
return np.any(self.missing)
|
||||
|
||||
@lru_property
|
||||
def has_missing(self) -> bool:
|
||||
return self.any_missing and not np.any(np.isnan(self.points[self.missing]))
|
||||
|
||||
@lru_property
|
||||
def cam_pos_unsqueezed(self) -> H5Array:
|
||||
if self.cam_pos.ndim != 1:
|
||||
return self.cam_pos
|
||||
else:
|
||||
cam_pos = self.cam_pos
|
||||
for _ in range(self.hits.ndim):
|
||||
cam_pos = cam_pos[None, ...]
|
||||
return cam_pos
|
||||
|
||||
@lru_property
|
||||
def cam_pos_unsqueezed_hit(self) -> H5Array:
|
||||
if self.cam_pos.ndim != 1:
|
||||
return self.cam_pos[self.hits, :]
|
||||
else:
|
||||
return self.cam_pos[None, :]
|
||||
|
||||
@lru_property
|
||||
def cam_pos_unsqueezed_miss(self) -> H5Array:
|
||||
if self.cam_pos.ndim != 1:
|
||||
return self.cam_pos[self.miss, :]
|
||||
else:
|
||||
return self.cam_pos[None, :]
|
||||
|
||||
@lru_property
|
||||
def ray_dirs(self) -> H5Array:
|
||||
return (self.points - self.cam_pos_unsqueezed) * (1 / self.depths[..., None])
|
||||
|
||||
@lru_property
|
||||
def ray_dirs_hit(self) -> H5Array:
|
||||
out = self.points[self.hits, :] - self.cam_pos_unsqueezed_hit
|
||||
out /= np.linalg.norm(out, axis=-1)[..., None] # normalize
|
||||
return out
|
||||
|
||||
@lru_property
|
||||
def ray_dirs_miss(self) -> H5Array:
|
||||
out = self.points[self.miss, :] - self.cam_pos_unsqueezed_miss
|
||||
out /= np.linalg.norm(out, axis=-1)[..., None] # normalize
|
||||
return out
|
||||
|
||||
@lru_property
|
||||
def depths(self) -> H5Array:
|
||||
return np.linalg.norm(self.points - self.cam_pos_unsqueezed, axis=-1)
|
||||
|
||||
@lru_property
|
||||
def missing(self) -> H5Array:
|
||||
return ~(self.hits | self.miss)
|
||||
|
||||
@classmethod
|
||||
def from_mesh_single_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewUVScan":
|
||||
if "phi" not in kw and not "theta" in kw:
|
||||
kw["theta"], kw["phi"] = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
|
||||
scan = sample_single_view_scan_from_mesh(mesh, **kw).to_uv_scan()
|
||||
if compute_miss_distances:
|
||||
scan.compute_miss_distances()
|
||||
assert scan.is_2d
|
||||
return scan
|
||||
|
||||
@classmethod
|
||||
def from_mesh_sphere_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewUVScan":
|
||||
scan = sample_sphere_view_scan_from_mesh(mesh, **kw)
|
||||
if compute_miss_distances:
|
||||
surface_points = None
|
||||
if scan.hits.sum() > mesh.vertices.shape[0]:
|
||||
surface_points = mesh.vertices.astype(scan.points.dtype)
|
||||
if not kw.get("no_unit_sphere", False):
|
||||
translation, scale = compute_unit_sphere_transform(mesh, dtype=scan.points.dtype)
|
||||
surface_points = (surface_points + translation) * scale
|
||||
scan.compute_miss_distances(surface_points=surface_points)
|
||||
assert scan.is_flat
|
||||
return scan
|
||||
|
||||
def flatten_and_permute_(self: _T, copy=False) -> _T: # inplace by default
|
||||
n_items = np.product(self.hits.shape)
|
||||
permutation = np.random.permutation(n_items)
|
||||
|
||||
out = self.copy(deep=False) if copy else self
|
||||
out.hits = out.hits .reshape((n_items, ))[permutation]
|
||||
out.miss = out.miss .reshape((n_items, ))[permutation]
|
||||
out.points = out.points .reshape((n_items, 3))[permutation, :]
|
||||
out.normals = out.normals .reshape((n_items, 3))[permutation, :] if out.has_normals else None
|
||||
out.colors = out.colors .reshape((n_items, 3))[permutation, :] if out.has_colors else None
|
||||
out.distances = out.distances.reshape((n_items, ))[permutation] if out.has_miss_distances else None
|
||||
return out
|
||||
|
||||
@property
|
||||
def is_single_view(self) -> bool:
|
||||
return np.product(self.cam_pos.shape[:-1]) == 1 if not self.cam_pos is None else True
|
||||
|
||||
@property
|
||||
def is_flat(self) -> bool:
|
||||
return len(self.hits.shape) == 1
|
||||
|
||||
@property
|
||||
def is_2d(self) -> bool:
|
||||
return len(self.hits.shape) == 2
|
||||
|
||||
|
||||
# transforms can be found in pytorch3d.transforms and in open3d
|
||||
# and in trimesh.transformations
|
||||
|
||||
def sample_single_view_scans_from_mesh(
|
||||
mesh : Trimesh,
|
||||
*,
|
||||
n_batches : int,
|
||||
scan_resolution : int = 400,
|
||||
compute_normals : bool = False,
|
||||
fov : float = 1.0472, # 60 degrees in radians, vertical field of view.
|
||||
camera_distance : float = 2,
|
||||
no_filter_backhits : bool = False,
|
||||
) -> Iterable[SingleViewScan]:
|
||||
|
||||
normalized_mesh_cache = []
|
||||
|
||||
for _ in range(n_batches):
|
||||
theta, phi = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
|
||||
|
||||
yield sample_single_view_scan_from_mesh(
|
||||
mesh = mesh,
|
||||
phi = phi,
|
||||
theta = theta,
|
||||
_mesh_is_normalized = False,
|
||||
scan_resolution = scan_resolution,
|
||||
compute_normals = compute_normals,
|
||||
fov = fov,
|
||||
camera_distance = camera_distance,
|
||||
no_filter_backhits = no_filter_backhits,
|
||||
_mesh_cache = normalized_mesh_cache,
|
||||
)
|
||||
|
||||
def sample_single_view_scan_from_mesh(
|
||||
mesh : Trimesh,
|
||||
*,
|
||||
phi : float,
|
||||
theta : float,
|
||||
scan_resolution : int = 200,
|
||||
compute_normals : bool = False,
|
||||
fov : float = 1.0472, # 60 degrees in radians, vertical field of view.
|
||||
camera_distance : float = 2,
|
||||
no_filter_backhits : bool = False,
|
||||
no_unit_sphere : bool = False,
|
||||
dtype : type = np.float32,
|
||||
_mesh_cache : Optional[list] = None, # provide a list if mesh is reused
|
||||
) -> SingleViewScan:
|
||||
|
||||
# scale and center to unit sphere
|
||||
is_cache = isinstance(_mesh_cache, list)
|
||||
if is_cache and _mesh_cache and _mesh_cache[0] is mesh:
|
||||
_, mesh, translation, scale = _mesh_cache
|
||||
else:
|
||||
if is_cache:
|
||||
if _mesh_cache:
|
||||
_mesh_cache.clear()
|
||||
_mesh_cache.append(mesh)
|
||||
translation, scale = compute_unit_sphere_transform(mesh)
|
||||
mesh = mesh_to_sdf.scale_to_unit_sphere(mesh)
|
||||
if is_cache:
|
||||
_mesh_cache.extend((mesh, translation, scale))
|
||||
|
||||
z_near = 1
|
||||
z_far = 3
|
||||
cam_mat4 = sdf_scan.get_camera_transform_looking_at_origin(phi, theta, camera_distance=camera_distance)
|
||||
cam_pos = cam_mat4 @ np.array([0, 0, 0, 1])
|
||||
|
||||
scan = sdf_scan.Scan(mesh,
|
||||
camera_transform = cam_mat4,
|
||||
resolution = scan_resolution,
|
||||
calculate_normals = compute_normals,
|
||||
fov = fov,
|
||||
z_near = z_near,
|
||||
z_far = z_far,
|
||||
no_flip_backfaced_normals = True
|
||||
)
|
||||
|
||||
# all the scan rays that hit the far plane, based on sdf_scan.Scan.__init__
|
||||
misses = np.argwhere(scan.depth_buffer == 0)
|
||||
points_miss = np.ones((misses.shape[0], 4))
|
||||
points_miss[:, [1, 0]] = misses.astype(float) / (scan_resolution -1) * 2 - 1
|
||||
points_miss[:, 1] *= -1
|
||||
points_miss[:, 2] = 1 # far plane in clipping space
|
||||
points_miss = points_miss @ (cam_mat4 @ np.linalg.inv(scan.projection_matrix)).T
|
||||
points_miss /= points_miss[:, 3][:, np.newaxis]
|
||||
points_miss = points_miss[:, :3]
|
||||
|
||||
uv_hits = scan.depth_buffer != 0
|
||||
uv_miss = ~uv_hits
|
||||
|
||||
if not no_filter_backhits:
|
||||
if not compute_normals:
|
||||
raise ValueError("not `no_filter_backhits` requires `compute_normals`")
|
||||
# inner product
|
||||
mask = np.einsum('ij,ij->i', scan.points - cam_pos[:3][None, :], scan.normals) < 0
|
||||
scan.points = scan.points [mask, :]
|
||||
scan.normals = scan.normals[mask, :]
|
||||
uv_hits[uv_hits] = mask
|
||||
|
||||
transforms = {}
|
||||
|
||||
# undo unit-sphere transform
|
||||
if no_unit_sphere:
|
||||
scan.points = scan.points * (1 / scale) - translation
|
||||
points_miss = points_miss * (1 / scale) - translation
|
||||
cam_pos[:3] = cam_pos[:3] * (1 / scale) - translation
|
||||
cam_mat4[:3, :] *= 1 / scale
|
||||
cam_mat4[:3, 3] -= translation
|
||||
|
||||
transforms["unit_sphere"] = T.scale_and_translate(scale=scale, translate=translation)
|
||||
transforms["model"] = np.eye(4)
|
||||
else:
|
||||
transforms["model"] = np.linalg.inv(T.scale_and_translate(scale=scale, translate=translation))
|
||||
transforms["unit_sphere"] = np.eye(4)
|
||||
|
||||
return SingleViewScan(
|
||||
normals_hit = scan.normals .astype(dtype),
|
||||
points_hit = scan.points .astype(dtype),
|
||||
points_miss = points_miss .astype(dtype),
|
||||
distances_miss = None,
|
||||
colors_hit = None,
|
||||
colors_miss = None,
|
||||
uv_hits = uv_hits .astype(bool),
|
||||
uv_miss = uv_miss .astype(bool),
|
||||
cam_pos = cam_pos[:3] .astype(dtype),
|
||||
cam_mat4 = cam_mat4 .astype(dtype),
|
||||
proj_mat4 = scan.projection_matrix .astype(dtype),
|
||||
transforms = {k:v.astype(dtype) for k, v in transforms.items()},
|
||||
)
|
||||
|
||||
def sample_sphere_view_scan_from_mesh(
|
||||
mesh : Trimesh,
|
||||
*,
|
||||
sphere_points : int = 4000, # resulting rays are n*(n-1)
|
||||
compute_normals : bool = False,
|
||||
no_filter_backhits : bool = False,
|
||||
no_unit_sphere : bool = False,
|
||||
no_permute : bool = False,
|
||||
dtype : type = np.float32,
|
||||
**kw,
|
||||
) -> SingleViewUVScan:
|
||||
translation, scale = compute_unit_sphere_transform(mesh, dtype=dtype)
|
||||
|
||||
# get unit-sphere points, then transform to model space
|
||||
two_sphere = generate_equidistant_sphere_rays(sphere_points, **kw).astype(dtype) # (n*(n-1), 2, 3)
|
||||
two_sphere = two_sphere / scale - translation # we transform after cache lookup
|
||||
|
||||
if mesh.ray.__class__.__module__.split(".")[-1] != "ray_pyembree":
|
||||
warnings.warn("Pyembree not found, the ray-tracing will be SLOW!")
|
||||
|
||||
(
|
||||
locations,
|
||||
index_ray,
|
||||
index_tri,
|
||||
) = mesh.ray.intersects_location(
|
||||
two_sphere[:, 0, :],
|
||||
two_sphere[:, 1, :] - two_sphere[:, 0, :], # direction, not target coordinate
|
||||
multiple_hits=False,
|
||||
)
|
||||
|
||||
|
||||
if compute_normals:
|
||||
location_normals = mesh.face_normals[index_tri]
|
||||
|
||||
batch = two_sphere.shape[:1]
|
||||
hits = np.zeros((*batch,), dtype=np.bool)
|
||||
miss = np.ones((*batch,), dtype=np.bool)
|
||||
cam_pos = two_sphere[:, 0, :]
|
||||
intersections = two_sphere[:, 1, :] # far-plane, effectively
|
||||
normals = np.zeros((*batch, 3), dtype=dtype)
|
||||
|
||||
index_ray_front = index_ray
|
||||
if not no_filter_backhits:
|
||||
if not compute_normals:
|
||||
raise ValueError("not `no_filter_backhits` requires `compute_normals`")
|
||||
mask = ((intersections[index_ray] - cam_pos[index_ray]) * location_normals).sum(axis=-1) <= 0
|
||||
index_ray_front = index_ray[mask]
|
||||
|
||||
|
||||
hits[index_ray_front] = True
|
||||
miss[index_ray] = False
|
||||
intersections[index_ray] = locations
|
||||
normals[index_ray] = location_normals
|
||||
|
||||
|
||||
if not no_permute:
|
||||
assert len(batch) == 1, batch
|
||||
permutation = np.random.permutation(*batch)
|
||||
hits = hits [permutation]
|
||||
miss = miss [permutation]
|
||||
intersections = intersections[permutation, :]
|
||||
normals = normals [permutation, :]
|
||||
cam_pos = cam_pos [permutation, :]
|
||||
|
||||
# apply unit sphere transform
|
||||
if not no_unit_sphere:
|
||||
intersections = (intersections + translation) * scale
|
||||
cam_pos = (cam_pos + translation) * scale
|
||||
|
||||
return SingleViewUVScan(
|
||||
hits = hits,
|
||||
miss = miss,
|
||||
points = intersections,
|
||||
normals = normals,
|
||||
colors = None, # colors
|
||||
distances = None,
|
||||
cam_pos = cam_pos,
|
||||
cam_mat4 = None,
|
||||
proj_mat4 = None,
|
||||
transforms = {},
|
||||
)
|
||||
|
||||
def distance_from_rays_to_point_cloud(
|
||||
ray_origins : np.ndarray, # (*A, 3)
|
||||
ray_dirs : np.ndarray, # (*A, 3)
|
||||
points : np.ndarray, # (*B, 3)
|
||||
dirs_normalized : bool = False,
|
||||
n_steps : int = 40,
|
||||
) -> np.ndarray: # (A)
|
||||
|
||||
# anything outside of this volume will never constribute to the result
|
||||
max_norm = max(
|
||||
np.linalg.norm(ray_origins, axis=-1).max(),
|
||||
np.linalg.norm(points, axis=-1).max(),
|
||||
) * 1.02
|
||||
|
||||
if not dirs_normalized:
|
||||
ray_dirs = ray_dirs / np.linalg.norm(ray_dirs, axis=-1)[..., None]
|
||||
|
||||
|
||||
# deal with single-view clouds
|
||||
if ray_origins.shape != ray_dirs.shape:
|
||||
ray_origins = np.broadcast_to(ray_origins, ray_dirs.shape)
|
||||
|
||||
n_points = np.product(points.shape[:-1])
|
||||
use_faiss = n_points > 160000*4
|
||||
if not use_faiss:
|
||||
index = BallTree(points)
|
||||
else:
|
||||
# http://ann-benchmarks.com/index.html
|
||||
assert np.issubdtype(points.dtype, np.float32)
|
||||
assert np.issubdtype(ray_origins.dtype, np.float32)
|
||||
assert np.issubdtype(ray_dirs.dtype, np.float32)
|
||||
index = faiss.index_factory(points.shape[-1], "NSG32,Flat") # https://github.com/facebookresearch/faiss/wiki/The-index-factory
|
||||
|
||||
index.nprobe = 5 # 10 # default is 1
|
||||
index.train(points)
|
||||
index.add(points)
|
||||
|
||||
if not use_faiss:
|
||||
min_d, min_n = index.query(ray_origins, k=1, return_distance=True)
|
||||
else:
|
||||
min_d, min_n = index.search(ray_origins, k=1)
|
||||
min_d = np.sqrt(min_d)
|
||||
acc_d = min_d.copy()
|
||||
|
||||
for step in range(1, n_steps+1):
|
||||
query_points = ray_origins + acc_d * ray_dirs
|
||||
if max_norm is not None:
|
||||
qmask = np.linalg.norm(query_points, axis=-1) < max_norm
|
||||
if not qmask.any(): break
|
||||
query_points = query_points[qmask]
|
||||
else:
|
||||
qmask = slice(None)
|
||||
if not use_faiss:
|
||||
current_d, current_n = index.query(query_points, k=1, return_distance=True)
|
||||
else:
|
||||
current_d, current_n = index.search(query_points, k=1)
|
||||
current_d = np.sqrt(current_d)
|
||||
if max_norm is not None:
|
||||
min_d[qmask] = np.minimum(current_d, min_d[qmask])
|
||||
new_min_mask = min_d[qmask] == current_d
|
||||
qmask2 = qmask.copy()
|
||||
qmask2[qmask2] = new_min_mask[..., 0]
|
||||
min_n[qmask2] = current_n[new_min_mask[..., 0]]
|
||||
acc_d[qmask] += current_d * 0.25
|
||||
else:
|
||||
np.minimum(current_d, min_d, out=min_d)
|
||||
new_min_mask = min_d == current_d
|
||||
min_n[new_min_mask] = current_n[new_min_mask]
|
||||
acc_d += current_d * 0.25
|
||||
|
||||
closest_points = points[min_n[:, 0], :] # k=1
|
||||
distances = np.linalg.norm(np.cross(closest_points - ray_origins, ray_dirs, axis=-1), axis=-1)
|
||||
return distances
|
||||
|
||||
# helpers
|
||||
|
||||
@compose(np.array) # make copy to avoid lru cache mutation
|
||||
@lru_cache(maxsize=1)
|
||||
def generate_equidistant_sphere_rays(n : int, **kw) -> np.ndarray: # output (n*n(-1)) rays, n may be off
|
||||
sphere_points = points.generate_equidistant_sphere_points(n=n, **kw)
|
||||
|
||||
indices = np.indices((len(sphere_points),))[0] # (N)
|
||||
# cartesian product
|
||||
cprod = np.transpose([np.tile(indices, len(indices)), np.repeat(indices, len(indices))]) # (N**2, 2)
|
||||
# filter repeated combinations
|
||||
permutations = cprod[cprod[:, 0] != cprod[:, 1], :] # (N*(N-1), 2)
|
||||
# lookup sphere points
|
||||
two_sphere = sphere_points[permutations, :] # (N*(N-1), 2, 3)
|
||||
|
||||
return two_sphere
|
||||
|
||||
def compute_unit_sphere_transform(mesh: Trimesh, *, dtype=type) -> tuple[np.ndarray, float]:
|
||||
"""
|
||||
returns translation and scale which mesh_to_sdf applies to meshes before computing their SDF cloud
|
||||
"""
|
||||
# the transformation applied by mesh_to_sdf.scale_to_unit_sphere(mesh)
|
||||
translation = -mesh.bounding_box.centroid
|
||||
scale = 1 / np.max(np.linalg.norm(mesh.vertices + translation, axis=1))
|
||||
if dtype is not None:
|
||||
translation = translation.astype(dtype)
|
||||
scale = scale .astype(dtype)
|
||||
return translation, scale
|
||||
6
ifield/data/common/types.py
Normal file
6
ifield/data/common/types.py
Normal file
@@ -0,0 +1,6 @@
|
||||
__doc__ = """
|
||||
Some helper types.
|
||||
"""
|
||||
|
||||
class MalformedMesh(Exception):
|
||||
pass
|
||||
Reference in New Issue
Block a user