This commit is contained in:
2023-07-19 19:29:10 +02:00
parent b2a64395bd
commit 4f811cc4b0
60 changed files with 18209 additions and 1 deletions

3
ifield/data/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
__doc__ = """
Submodules to read and process datasets
"""

View File

View File

@@ -0,0 +1,90 @@
from ...utils.helpers import make_relative
from pathlib import Path
from tqdm import tqdm
from typing import Union, Optional
import io
import os
import json
import requests
PathLike = Union[os.PathLike, str]
__doc__ = """
Here are some helper functions for processing data.
"""
def check_url(url):
    """Return True when `url` answers a HTTP HEAD request with a success status."""
    response = requests.head(url)
    return response.ok
def download_stream(
        url        : str,
        file_object,
        block_size : int  = 1024,
        silent     : bool = False,
        label      : Optional[str] = None,
        ):
    """
    Stream-download `url` into the writable binary `file_object`.

    Parameters:
        block_size - chunk size in bytes for each read/write
        silent     - suppress the tqdm progress bar
        label      - description shown on the progress bar

    Prints an error if the number of bytes received differs from the
    server-announced content-length (when one was announced).
    """
    resp = requests.get(url, stream=True)
    total_size = int(resp.headers.get("content-length", 0))
    progress_bar = None
    if not silent:
        progress_bar = tqdm(total=total_size, unit="iB", unit_scale=True, desc=label)
    n_written = 0  # bug fix: previously the final size check read progress_bar.n,
                   # which raised NameError whenever silent=True
    for chunk in resp.iter_content(block_size):
        if progress_bar is not None:
            progress_bar.update(len(chunk))
        file_object.write(chunk)
        n_written += len(chunk)
    if progress_bar is not None:
        progress_bar.close()
    if total_size != 0 and n_written != total_size:
        print("ERROR, something went wrong")
def download_data(
        url        : str,
        block_size : int  = 1024,
        silent     : bool = False,
        label      : Optional[str] = None,
        ) -> bytearray:
    """Download `url` into memory and return its full content as a bytearray."""
    buffer = io.BytesIO()
    download_stream(url, buffer, block_size=block_size, silent=silent, label=label)
    return bytearray(buffer.getvalue())
def download_file(
        url        : str,
        fname      : Union[Path, str],
        block_size : int = 1024,
        silent     = False,
        ):
    """Download `url` to the file at `fname`, labeling the progress bar with its name."""
    path = fname if isinstance(fname, Path) else Path(fname)
    label = make_relative(path, Path.cwd()).name
    with path.open("wb") as file_object:
        download_stream(url, file_object, block_size=block_size, silent=silent, label=label)
def is_downloaded(
        target_dir : Union[os.PathLike, str],
        url        : str,
        *,
        add        : bool = False,
        dbfiles    : Union[list[Union[os.PathLike, str]], os.PathLike, str],
        ):
    """
    Check whether `url` is registered as downloaded in any of the JSON db files
    found under `target_dir`.

    Parameters:
        add     - when True and the url is unregistered, register it in the
                  first db file and return True
        dbfiles - one or more db file names, interpreted relative to target_dir

    Raises ValueError when `dbfiles` is an empty list.
    """
    if not isinstance(target_dir, os.PathLike):
        target_dir = Path(target_dir)
    if not isinstance(dbfiles, list):
        dbfiles = [dbfiles]
    if not dbfiles:
        raise ValueError("'dbfiles' empty")
    downloaded = set()
    for dbfile_fname in dbfiles:
        dbfile_fname = target_dir / dbfile_fname
        if dbfile_fname.is_file():
            with open(dbfile_fname, "r") as f:
                downloaded.update(json.load(f)["downloaded"])
    if add and url not in downloaded:
        downloaded.add(url)
        # bug fix: write relative to target_dir like the reads above
        # (previously wrote dbfiles[0] relative to the current working directory)
        with open(target_dir / dbfiles[0], "w") as f:
            data = {"downloaded": sorted(downloaded)}
            json.dump(data, f, indent=2, sort_keys=True)
        return True
    return url in downloaded

View File

@@ -0,0 +1,370 @@
#!/usr/bin/env python3
from abc import abstractmethod, ABCMeta
from collections import namedtuple
from pathlib import Path
import copy
import dataclasses
import functools
import h5py as h5
import hdf5plugin
import numpy as np
import operator
import os
import sys
import typing
__all__ = [
"DataclassMeta",
"Dataclass",
"H5Dataclass",
"H5Array",
"H5ArrayNoSlice",
]
T = typing.TypeVar("T")
NoneType = type(None)
PathLike = typing.Union[os.PathLike, str]
H5Array = typing._alias(np.ndarray, 0, inst=False, name="H5Array")
H5ArrayNoSlice = typing._alias(np.ndarray, 0, inst=False, name="H5ArrayNoSlice")
DataclassField = namedtuple("DataclassField", [
"name",
"type",
"is_optional",
"is_array",
"is_sliceable",
"is_prefix",
])
def strip_optional(val: type) -> type:
    """Unwrap `typing.Optional[X]` to `X`; other types pass through unchanged."""
    if typing.get_origin(val) is not typing.Union:
        return val
    members = set(typing.get_args(val)) - {type(None)}
    if len(members) != 1:
        # only Optional (i.e. Union[X, None]) is supported
        raise TypeError(f"Non-'typing.Optional' 'typing.Union' is not supported: {typing._type_repr(val)!r}")
    return next(iter(members))
def is_array(val, *, _inner=False):
    """
    Hacky way to check if a value or type is an array.
    The hack omits having to depend on large frameworks such as pytorch or pandas
    """
    val = strip_optional(val)
    if val is H5Array or val is H5ArrayNoSlice:
        return True
    known_array_reprs = (
        "numpy.ndarray",
        "torch.Tensor",
    )
    if typing._type_repr(val) in known_array_reprs:
        return True
    # first pass got a value rather than a type: retry once with its type
    return is_array(type(val), _inner=True) if not _inner else False
def prod(numbers: typing.Iterable[T], initial: typing.Optional[T] = None) -> T:
    """Multiply all `numbers` together, optionally seeded with `initial`."""
    if initial is None:
        return functools.reduce(operator.mul, numbers)
    return functools.reduce(operator.mul, numbers, initial)
class DataclassMeta(type):
    """
    Metaclass that applies the `dataclasses.dataclass` decorator to every class
    created through it.

    On Python >= 3.10, and only when the class does not already declare
    `__slots__`, `dataclass(slots=True)` is used — note that with `slots=True`
    dataclasses builds and returns a *replacement* class object.
    """
    def __new__(
            mcls,
            name  : str,
            bases : tuple[type, ...],
            attrs : dict[str, typing.Any],
            **kwargs,
            ):
        # create the class normally first, then wrap it as a dataclass
        cls = super().__new__(mcls, name, bases, attrs, **kwargs)
        if sys.version_info[:2] >= (3, 10) and not hasattr(cls, "__slots__"):
            # slots=True is only available on 3.10+
            cls = dataclasses.dataclass(slots=True)(cls)
        else:
            cls = dataclasses.dataclass(cls)
        return cls
class DataclassABCMeta(DataclassMeta, ABCMeta):
    """Metaclass combining DataclassMeta with ABCMeta, for abstract dataclasses."""
    pass
class Dataclass(metaclass=DataclassMeta):
    """
    Dataclass base providing a mapping-like interface (`keys`/`values`/`items`
    and `[]` access restricted to field names) plus dict/tuple export and copying.
    """
    def __getitem__(self, key: str) -> typing.Any:
        if key not in self.keys():
            raise KeyError(key)
        return getattr(self, key)
    def __setitem__(self, key: str, value: typing.Any):
        if key not in self.keys():
            raise KeyError(key)
        return setattr(self, key, value)
    def keys(self) -> typing.KeysView:
        return self.as_dict().keys()
    def values(self) -> typing.ValuesView:
        return self.as_dict().values()
    def items(self) -> typing.ItemsView:
        return self.as_dict().items()
    def as_dict(self, properties_to_include: set[str] = None, **kw) -> dict[str, typing.Any]:
        """Like `dataclasses.asdict`, optionally evaluating named properties as well."""
        out = dataclasses.asdict(self, **kw)
        for prop_name in properties_to_include or []:
            out[prop_name] = getattr(self, prop_name)
        return out
    def as_tuple(self, properties_to_include: list[str]) -> tuple:
        """Like `dataclasses.astuple`, optionally appending named properties."""
        base = dataclasses.astuple(self)
        if not properties_to_include:
            return base
        extras = tuple(getattr(self, prop_name) for prop_name in properties_to_include)
        return base + extras
    def copy(self: T, *, deep=True) -> T:
        """Return a copy of this instance (deep by default)."""
        copier = copy.deepcopy if deep else copy.copy
        return copier(self)
class H5Dataclass(Dataclass):
    """
    Dataclass whose fields are read from / written to datasets in a HDF5 file.

    Class keyword parameters (e.g. `class Foo(H5Dataclass, prefix="x_")`):
        prefix      - common name prefix of all HDF5 datasets
        n_pages     - default number of pages for paginated reads
        require_all - when reading, require every HDF5 dataset to be consumed

    Field conventions: `H5Array` fields are sliceable arrays, `H5ArrayNoSlice`
    arrays are always read whole, `dict[str, X]` fields gather every dataset
    sharing the `<prefix><field>_` name prefix.
    """
    # settable with class params:
    _prefix      : str  = dataclasses.field(init=False, repr=False, default="")
    _n_pages     : int  = dataclasses.field(init=False, repr=False, default=10)
    _require_all : bool = dataclasses.field(init=False, repr=False, default=False)

    def __init_subclass__(cls,
            prefix      : typing.Optional[str]  = None,
            n_pages     : typing.Optional[int]  = None,
            require_all : typing.Optional[bool] = None,
            **kw,
            ):
        super().__init_subclass__(**kw)
        assert dataclasses.is_dataclass(cls)
        if prefix      is not None: cls._prefix      = prefix
        if n_pages     is not None: cls._n_pages     = n_pages
        if require_all is not None: cls._require_all = require_all

    @classmethod
    def _get_fields(cls) -> typing.Iterable[DataclassField]:
        """Yield a DataclassField descriptor for each `__init__` field."""
        for field in dataclasses.fields(cls):
            if not field.init:
                continue
            assert field.name not in ("_prefix", "_n_pages", "_require_all"), (
                f"{field.name!r} can not be in {cls.__qualname__}.__init__.\n"
                "Set it with dataclasses.field(default=YOUR_VALUE, init=False, repr=False)"
            )
            if isinstance(field.type, str):
                raise TypeError("Type hints are strings, perhaps avoid using `from __future__ import annotations`")
            type_inner = strip_optional(field.type)
            # dict[str, X] fields gather multiple datasets sharing a name prefix
            is_prefix  = typing.get_origin(type_inner) is dict and typing.get_args(type_inner)[:1] == (str,)
            field_type = typing.get_args(type_inner)[1] if is_prefix else field.type
            if field.default is None or typing.get_origin(field.type) is typing.Union and NoneType in typing.get_args(field.type):
                field_type = typing.Optional[field_type]
            yield DataclassField(
                name         = field.name,
                type         = strip_optional(field_type),
                is_optional  = typing.get_origin(field_type) is typing.Union and NoneType in typing.get_args(field_type),
                is_array     = is_array(field_type),
                is_sliceable = is_array(field_type) and strip_optional(field_type) is not H5ArrayNoSlice,
                is_prefix    = is_prefix,
            )

    @classmethod
    def from_h5_file(cls : type[T],
            fname              : typing.Union[PathLike, str],
            *,
            page               : typing.Optional[int] = None,
            n_pages            : typing.Optional[int] = None,
            read_slice         : slice = slice(None),
            require_even_pages : bool  = True,
            ) -> T:
        """
        Read an instance from the HDF5 file at `fname`.

        Parameters:
            page / n_pages     - read only the page-th of n_pages slices of
                                 each sliceable field (n_pages defaults to the
                                 class-level _n_pages)
            read_slice         - slice applied to sliceable fields when not paging
            require_even_pages - raise when a field does not divide evenly

        Raises FileNotFoundError / TypeError for missing or non-HDF5 files, and
        NotImplementedError (with a suggested class body) when the class has no fields.
        """
        if not isinstance(fname, Path):
            fname = Path(fname)
        if n_pages is None:
            n_pages = cls._n_pages
        if not fname.exists():
            raise FileNotFoundError(str(fname))
        if not h5.is_hdf5(fname):
            raise TypeError(f"Not a HDF5 file: {str(fname)!r}")
        # if this class has no fields, print a example class:
        if not any(field.init for field in dataclasses.fields(cls)):
            with h5.File(fname, "r") as f:
                klen = max(map(len, f.keys()))
                example_cls = f"\nclass {cls.__name__}(Dataclass, require_all=True):\n" + "\n".join(
                    f"    {k.ljust(klen)} : "
                    + (
                        "H5Array" if prod(v.shape, 1) > 1 else (
                        "float"   if issubclass(v.dtype.type, np.floating) else (
                        "int"     if issubclass(v.dtype.type, np.integer) else (
                        "bool"    if issubclass(v.dtype.type, np.bool_) else (
                        "typing.Any"
                    ))))).ljust(14 + 1)
                    + f" #{repr(v).split(':', 1)[1].removesuffix('>')}"
                    for k, v in f.items()
                )
            raise NotImplementedError(f"{cls!r} has no fields!\nPerhaps try the following:{example_cls}")
        fields_consumed = set()
        def make_kwarg(
                file  : h5.File,
                keys  : typing.KeysView,
                field : DataclassField,
                ) -> tuple[str, typing.Any]:
            # read a single field, honoring pagination/slicing for sliceable arrays
            if field.is_optional:
                if field.name not in keys:
                    return field.name, None
            if field.is_sliceable:
                if page is not None:
                    n_items  = int(file[cls._prefix + field.name].shape[0])
                    page_len = n_items // n_pages
                    modulus  = n_items % n_pages
                    if modulus: page_len += 1 # round up
                    if require_even_pages and modulus:
                        raise ValueError(f"Field {field.name!r} {tuple(file[cls._prefix + field.name].shape)} is not cleanly divisible into {n_pages} pages")
                    # bug fix: slice() accepts no keyword arguments; the previous
                    # slice(start=..., stop=..., step=...) raised TypeError on every paged read
                    this_slice = slice(
                        page_len * page,       # start
                        page_len * (page + 1), # stop
                        read_slice.step,       # inherit step
                    )
                else:
                    this_slice = read_slice
            else:
                this_slice = slice(None) # read all
            # array or scalar?
            def read_dataset(var):
                # https://docs.h5py.org/en/stable/high/dataset.html#reading-writing-data
                if field.is_array:
                    return var[this_slice]
                if var.shape == (1,):
                    return var[0]
                else:
                    return var[()]
            if field.is_prefix:
                fields_consumed.update(
                    key
                    for key in keys if key.startswith(f"{cls._prefix}{field.name}_")
                )
                return field.name, {
                    key.removeprefix(f"{cls._prefix}{field.name}_") : read_dataset(file[key])
                    for key in keys if key.startswith(f"{cls._prefix}{field.name}_")
                }
            else:
                fields_consumed.add(cls._prefix + field.name)
                return field.name, read_dataset(file[cls._prefix + field.name])
        with h5.File(fname, "r") as f:
            keys = f.keys()
            init_dict = dict( make_kwarg(f, keys, i) for i in cls._get_fields() )
            try:
                out = cls(**init_dict)
            except Exception as e:
                # augment the error with the field/file key diff for debugging
                class_attrs = set(field.name for field in dataclasses.fields(cls))
                file_attr   = set(init_dict.keys())
                raise e.__class__(f"{e}. {class_attrs=}, {file_attr=}, diff={class_attrs.symmetric_difference(file_attr)}") from e
            if cls._require_all:
                fields_not_consumed = set(keys) - fields_consumed
                if fields_not_consumed:
                    raise ValueError(f"Not all HDF5 fields consumed: {fields_not_consumed!r}")
        return out

    def to_h5_file(self,
            fname : PathLike,
            mkdir : bool = False,
            ):
        """
        Write this instance to a HDF5 file; arrays are LZ4-compressed.

        Raises NotADirectoryError when the parent directory is missing and
        mkdir is False, and TypeError for non-numpy array fields.
        """
        if not isinstance(fname, Path):
            fname = Path(fname)
        if not fname.parent.is_dir():
            if mkdir:
                fname.parent.mkdir(parents=True)
            else:
                raise NotADirectoryError(fname.parent)
        with h5.File(fname, "w") as f:
            for field in type(self)._get_fields():
                if field.is_optional and getattr(self, field.name) is None:
                    continue
                value = getattr(self, field.name)
                if field.is_array:
                    if any(type(i) is not np.ndarray for i in (value.values() if field.is_prefix else [value])):
                        raise TypeError(
                            "When dumping a H5Dataclass, make sure the array fields are "
                            f"numpy arrays (the type of {field.name!r} is {typing._type_repr(type(value))}).\n"
                            "Example: h5dataclass.map_arrays(torch.Tensor.numpy)"
                        )
                def write_value(key: str, value: typing.Any):
                    # compress arrays, store scalars as-is
                    if field.is_array:
                        f.create_dataset(key, data=value, **hdf5plugin.LZ4())
                    else:
                        f.create_dataset(key, data=value)
                if field.is_prefix:
                    for k, v in value.items():
                        write_value(self._prefix + field.name + "_" + k, v)
                else:
                    write_value(self._prefix + field.name, value)

    def map_arrays(self: T, func: typing.Callable[[H5Array], H5Array], do_copy: bool = False) -> T:
        """Apply `func` to every array field (including prefix dicts), in place unless do_copy."""
        if do_copy: # shallow
            self = self.copy(deep=False)
        for field in type(self)._get_fields():
            if field.is_optional and getattr(self, field.name) is None:
                continue
            if field.is_prefix and field.is_array:
                setattr(self, field.name, {
                    k : func(v)
                    for k, v in getattr(self, field.name).items()
                })
            elif field.is_array:
                setattr(self, field.name, func(getattr(self, field.name)))
        return self

    def astype(self: T, t: type, do_copy: bool = False, convert_nonfloats: bool = False) -> T:
        """Cast array fields to dtype `t`; integer arrays only when convert_nonfloats."""
        # bug fix: do_copy was previously accepted but never forwarded
        return self.map_arrays(
            lambda x: x.astype(t) if convert_nonfloats or not np.issubdtype(x.dtype, int) else x,
            do_copy,
        )

    def copy(self: T, *, deep=True) -> T:
        """Copy; the shallow variant still copies prefix dicts so they are not shared."""
        out = super().copy(deep=deep)
        if not deep:
            for field in type(self)._get_fields():
                if field.is_prefix:
                    # bug fix: previously copied the field *name* string
                    # instead of the dict value itself
                    out[field.name] = copy.copy(self[field.name])
        return out

    @property
    def shape(self) -> dict[str, tuple[int, ...]]:
        """Map each shaped field value to its shape."""
        return {
            key: value.shape
            for key, value in self.items()
            if hasattr(value, "shape")
        }
class TransformableDataclassMixin(metaclass=DataclassABCMeta):
    """
    Mixin for dataclasses carrying a `transforms` dict of named 4x4 matrices
    which can be applied ("consumed") with `transform_to`.
    """
    @abstractmethod
    def transform(self: T, mat4: np.ndarray, inplace=False) -> T:
        ...
    def transform_to(self: T, name: str, inverse_name: str = None, *, inplace=False) -> T:
        """Apply the stored transform `name`, rebasing the remaining transforms onto the new frame."""
        mtx = self.transforms[name]
        result = self.transform(mtx, inplace=inplace)
        del result.transforms[name] # consumed
        inv = np.linalg.inv(mtx)
        for key in tuple(result.transforms.keys()): # maintain the other transforms
            result.transforms[key] = result.transforms[key] @ inv
        if inverse_name is not None: # store inverse
            result.transforms[inverse_name] = inv
        return result

View File

@@ -0,0 +1,48 @@
from math import pi
from trimesh import Trimesh
import numpy as np
import os
import trimesh
import trimesh.transformations as T
DEBUG = bool(os.environ.get("IFIELD_DEBUG", ""))
__doc__ = """
Here are some helper functions for processing data.
"""
def rotate_to_closest_axis_aligned_bounds(
        mesh       : Trimesh,
        order_axes : bool = True,
        fail_ok    : bool = True,
        ) -> np.ndarray:
    """
    Compute a 4x4 rotation aligning the mesh's oriented bounding box with the
    world axes; with `order_axes` the candidate closest to the identity
    orientation (all Euler angles within ~45 degrees) is chosen.
    """
    to_origin_mat4, extents = trimesh.bounds.oriented_bounds(mesh, ordered=not order_axes)
    to_aabb_rot_mat4 = T.euler_matrix(*T.decompose_matrix(to_origin_mat4)[3])
    if not order_axes:
        return to_aabb_rot_mat4
    tolerance = pi / 4 * 1.01
    quarter   = pi / 2
    face_turns = (
        (0, 0),
        (1, 0),
        (2, 0),
        (3, 0),
        (0, 1),
        (0,-1),
    )
    # 6 faces x 4 in-plane rotations per face
    for spin in range(4):
        for fx, fy in face_turns:
            candidate = T.euler_matrix(fx * quarter, fy * quarter, spin * quarter) @ to_aabb_rot_mat4
            ai, aj, ak = T.euler_from_matrix(candidate)
            if max(abs(ai), abs(aj), abs(ak)) <= tolerance:
                return candidate
    if fail_ok:
        return to_aabb_rot_mat4
    raise Exception("Unable to orient mesh")

View File

@@ -0,0 +1,297 @@
from __future__ import annotations
from ...utils.helpers import compose
from functools import reduce, lru_cache
from math import ceil
from typing import Iterable
import numpy as np
import operator
__doc__ = """
Here are some helper functions for processing data.
"""
def img2col(img: np.ndarray, psize: int) -> np.ndarray:
    """
    Extract every (psize x psize) patch of `img` ("im2col").

    Returns an array where entry [i, j] is the patch whose top-left corner is
    at pixel (i, j), shaped (rows-psize+1, cols-psize+1[, n_channels], psize,
    psize) with size-one axes squeezed away.
    """
    # based of ycb_generate_point_cloud.py provided by YCB
    # (a dead `n_channels = ...` assignment that was immediately overwritten
    #  by the unpacking below has been removed)
    n_channels, rows, cols = (1,) * (3 - len(img.shape)) + img.shape
    # pad the image up to a multiple of psize
    img_pad = np.zeros((
        n_channels,
        int(ceil(1.0 * rows / psize) * psize),
        int(ceil(1.0 * cols / psize) * psize),
    ))
    img_pad[:, 0:rows, 0:cols] = img
    # allocate output buffer
    final = np.zeros((
        img_pad.shape[1],
        img_pad.shape[2],
        n_channels,
        psize,
        psize,
    ))
    for c in range(n_channels):
        for x in range(psize):
            for y in range(psize):
                # roll the image by (x, y) so that block-reshaping yields the
                # patches whose corner is congruent to (x, y) modulo psize
                img_shift = np.vstack((
                    img_pad[c, x:],
                    img_pad[c, :x]))
                img_shift = np.column_stack((
                    img_shift[:, y:],
                    img_shift[:, :y]))
                final[x::psize, y::psize, c] = np.swapaxes(
                    img_shift.reshape(
                        int(img_pad.shape[1] / psize), psize,
                        int(img_pad.shape[2] / psize), psize),
                    1,
                    2)
    # crop output and unwrap axes with size==1
    return np.squeeze(final[
        0:rows - psize + 1,
        0:cols - psize + 1])
def filter_depth_discontinuities(depth_map: np.ndarray, filt_size = 7, thresh = 1000) -> np.ndarray:
    """
    Removes data close to discontinuities, with size filt_size.
    """
    # based of ycb_generate_point_cloud.py provided by YCB
    # Ensure that filter sizes are okay
    assert filt_size % 2, "Can only use odd filter sizes."
    # Compute discontinuities
    offset = int(filt_size - 1) // 2
    patches = 1.0 * img2col(depth_map, filt_size)
    centers = patches[:, :, offset, offset]
    lo = np.min(patches, axis=(2, 3))
    hi = np.max(patches, axis=(2, 3))
    discontinuity = np.maximum(
        np.abs(lo - centers),
        np.abs(hi - centers))
    is_bad = discontinuity > thresh
    # Account for offsets: shift the mask back into the full-image frame
    final_mark = np.zeros(depth_map.shape, dtype=np.uint16)
    final_mark[offset:offset + is_bad.shape[0],
               offset:offset + is_bad.shape[1]] = is_bad
    return depth_map * (1 - final_mark)
def reorient_depth_map(
        depth_map       : np.ndarray,
        rgb_map         : np.ndarray,
        depth_mat3      : np.ndarray, # 3x3 intrinsic camera matrix
        depth_vec5      : np.ndarray, # 5 distortion parameters (k1, k2, p1, p2, k3)
        rgb_mat3        : np.ndarray, # 3x3 intrinsic camera matrix
        rgb_vec5        : np.ndarray, # 5 distortion parameters (k1, k2, p1, p2, k3)
        ir_to_rgb_mat4  : np.ndarray, # extrinsic transformation matrix from depth to rgb camera viewpoint
        rgb_mask_map    : np.ndarray = None,
        _output_points  = False, # retval (H, W) if false else (N, XYZRGB)
        _output_hits_uvs = False, # retval[1] is dtype=bool of hits shaped like depth_map
        ) -> np.ndarray:
    """
    Corrects depth_map to be from the same view as the rgb_map, with the same dimensions.
    If _output_points is True, the points returned are in the rgb camera space.
    """
    # based of ycb_generate_point_cloud.py provided by YCB
    # now faster AND more easy on the GIL
    height_old, width_old, *_ = depth_map.shape
    height,     width,     *_ = rgb_map.shape
    d_cx, r_cx = depth_mat3[0, 2], rgb_mat3[0, 2] # optical center
    d_cy, r_cy = depth_mat3[1, 2], rgb_mat3[1, 2]
    d_fx, r_fx = depth_mat3[0, 0], rgb_mat3[0, 0] # focal length
    d_fy, r_fy = depth_mat3[1, 1], rgb_mat3[1, 1]
    d_k1, d_k2, d_p1, d_p2, d_k3 = depth_vec5
    # NOTE(review): the rgb distortion parameters are unpacked but never applied —
    # rgb_map is assumed pre-undistorted; confirm upstream
    c_k1, c_k2, c_p1, c_p2, c_k3 = rgb_vec5
    # make a UV grid over depth_map
    u, v = np.meshgrid(
        np.arange(width_old),
        np.arange(height_old),
    )
    # compute xyz coordinates for all depths
    xyz_depth = np.stack((
        (u - d_cx) / d_fx,
        (v - d_cy) / d_fy,
        depth_map,
        np.ones(depth_map.shape)
    )).reshape((4, -1))
    xyz_depth = xyz_depth[:, xyz_depth[2] != 0] # drop missing (zero) depths
    # undistort depth coordinates
    d_x, d_y = xyz_depth[:2]
    r = np.linalg.norm(xyz_depth[:2], axis=0)
    xyz_depth[0, :] \
        = d_x / (1 + d_k1*r**2 + d_k2*r**4 + d_k3*r**6) \
        - (2*d_p1*d_x*d_y + d_p2*(r**2 + 2*d_x**2))
    xyz_depth[1, :] \
        = d_y / (1 + d_k1*r**2 + d_k2*r**4 + d_k3*r**6) \
        - (d_p1*(r**2 + 2*d_y**2) + 2*d_p2*d_x*d_y)
    # unproject x and y
    xyz_depth[0, :] *= xyz_depth[2, :]
    xyz_depth[1, :] *= xyz_depth[2, :]
    # convert depths to RGB camera viewpoint
    xyz_rgb = ir_to_rgb_mat4 @ xyz_depth
    # project depths to RGB canvas
    rgb_z_inv = 1 / xyz_rgb[2] # perspective correction
    # bug fix: np.int was removed in NumPy 1.24; the builtin is the same dtype
    rgb_uv = np.stack((
        xyz_rgb[0] * rgb_z_inv * r_fx + r_cx + 0.5,
        xyz_rgb[1] * rgb_z_inv * r_fy + r_cy + 0.5,
    )).astype(int)
    # mask of the rgb_xyz values within view of rgb_map
    mask = reduce(operator.and_, [
        rgb_uv[0] >= 0,
        rgb_uv[1] >= 0,
        rgb_uv[0] < width,
        rgb_uv[1] < height,
    ])
    if rgb_mask_map is not None:
        mask[mask] &= rgb_mask_map[
            rgb_uv[1, mask],
            rgb_uv[0, mask]]
    if not _output_points: # output image
        output = np.zeros((height, width), dtype=depth_map.dtype)
        output[
            rgb_uv[1, mask],
            rgb_uv[0, mask],
        ] = xyz_rgb[2, mask]
    else: # output pointcloud
        rgbs = rgb_map[ # lookup rgb values using rgb_uv
            rgb_uv[1, mask],
            rgb_uv[0, mask]]
        output = np.stack((
            xyz_rgb[0, mask], # x
            xyz_rgb[1, mask], # y
            xyz_rgb[2, mask], # z
            rgbs[:, 0], # r
            rgbs[:, 1], # g
            rgbs[:, 2], # b
        )).T
    # output for realsies
    if not _output_hits_uvs: #raw
        return output
    else: # with hit mask
        uv = np.zeros((height, width), dtype=bool)
        # filter points overlapping in the depth map
        uv_indices = (
            rgb_uv[1, mask],
            rgb_uv[0, mask],
        )
        _, chosen = np.unique( uv_indices[0] << 32 | uv_indices[1], return_index=True )
        output = output[chosen, :]
        uv[uv_indices] = True
        return output, uv
def join_rgb_and_depth_to_points(*a, **kw) -> np.ndarray:
    """Wrapper around `reorient_depth_map` returning an (N, XYZRGB) pointcloud."""
    return reorient_depth_map(*a, _output_points=True, **kw)
@compose(np.array) # block lru cache mutation
@lru_cache(maxsize=1)
@compose(list)
def generate_equidistant_sphere_points(
        n                          : int,
        centroid                   : np.ndarray = (0, 0, 0),
        radius                     : float = 1,
        compute_sphere_coordinates : bool = False,
        compute_normals            : bool = False,
        shift_theta                : bool = False,
        ) -> Iterable[tuple[float, ...]]:
    """
    Yield roughly `n` points spread evenly over a sphere.

    Deserno M. How to generate equidistributed points on the surface of a sphere
    https://www.cmu.edu/biolphys/deserno/pdf/sphere_equi.pdf
    """
    if compute_sphere_coordinates and compute_normals:
        raise ValueError(
            "'compute_sphere_coordinates' and 'compute_normals' are mutually exclusive"
        )
    area_per_point = 4 * np.pi / n
    d = np.sqrt(area_per_point)
    n_theta = round(np.pi / d)
    d_theta = np.pi / n_theta
    d_phi = area_per_point / d_theta
    for i in range(n_theta):
        theta = np.pi * (i + 0.5) / n_theta
        n_phi = round(2 * np.pi * np.sin(theta) / d_phi)
        for j in range(n_phi):
            phi = 2 * np.pi * j / n_phi
            if compute_sphere_coordinates: # (theta, phi)
                yield (
                    theta if shift_theta else theta - 0.5*np.pi,
                    phi,
                )
                continue
            normal = (
                np.sin(theta) * np.cos(phi),
                np.sin(theta) * np.sin(phi),
                np.cos(theta),
            )
            point = (
                centroid[0] + radius * normal[0],
                centroid[1] + radius * normal[1],
                centroid[2] + radius * normal[2],
            )
            # (x, y, z, nx, ny, nz) with normals, plain (x, y, z) otherwise
            yield point + normal if compute_normals else point
def generate_random_sphere_points(
        n                          : int,
        centroid                   : np.ndarray = (0, 0, 0),
        radius                     : float = 1,
        compute_sphere_coordinates : bool = False,
        compute_normals            : bool = False,
        shift_theta                : bool = False, # depends on convention
        ) -> np.ndarray:
    """
    Sample `n` uniformly distributed points on a sphere (via inverse transform
    sampling of the latitude). Returns (n, 2) sphere coordinates, (n, 6)
    points+normals, or (n, 3) points depending on the flags.
    """
    if compute_sphere_coordinates and compute_normals:
        raise ValueError(
            "'compute_sphere_coordinates' and 'compute_normals' are mutually exclusive"
        )
    theta = np.arcsin(np.random.uniform(-1, 1, n)) # inverse transform sampling
    phi   = np.random.uniform(0, 2*np.pi, n)
    if compute_sphere_coordinates: # (theta, phi)
        first = 0.5*np.pi + theta if shift_theta else theta
        return np.stack((first, phi), axis=1)
    normal_columns = (
        np.cos(theta) * np.cos(phi),
        np.cos(theta) * np.sin(phi),
        np.sin(theta),
    )
    point_columns = (
        centroid[0] + radius * normal_columns[0],
        centroid[1] + radius * normal_columns[1],
        centroid[2] + radius * normal_columns[2],
    )
    if compute_normals: # (x, y, z, nx, ny, nz)
        return np.stack(point_columns + normal_columns, axis=1)
    return np.stack(point_columns, axis=1) # (x, y, z)

View File

@@ -0,0 +1,85 @@
from .h5_dataclasses import H5Dataclass
from datetime import datetime, timedelta
from pathlib import Path
from typing import Hashable, Optional, Callable
import os
DEBUG = bool(os.environ.get("IFIELD_DEBUG", ""))
__doc__ = """
Here are some helper functions for processing data.
"""
# multiprocessing does not work due to my ridiculous use of closures, which seemingly cannot be pickled
# parallelize it in the shell instead
def precompute_data(
        computer     : Callable[[Hashable], Optional["H5Dataclass"]],
        identifiers  : list[Hashable],
        output_paths : list[Path],
        page         : tuple[int, int] = (0, 1),
        *,
        force        : bool = False,
        debug        : bool = False,
        ):
    """
    precomputes data and stores them as HDF5 datasets using `.to_file(path: Path)`

    Parameters:
        computer     - maps an identifier to a result with `.to_h5_file`, or None on failure
        identifiers  - job identifiers, parallel to output_paths
        output_paths - destination file per identifier
        page         - (page_index, n_pages) slice of the job list to run
        force        - recompute even when the output file already exists
        debug        - re-raise exceptions instead of just recording the failure
    """
    page, n_pages = page
    assert len(identifiers) == len(output_paths)
    total = len(identifiers)
    # bug fix: default=0 avoids ValueError when the job list is empty
    identifier_max_len = max(map(len, map(str, identifiers)), default=0)
    t_epoch = None
    def log(state: str, is_start = False):
        # progress line for the job currently bound to `index`/`identifier`
        nonlocal t_epoch
        if is_start: t_epoch = datetime.now()
        td = timedelta(0) if is_start else datetime.now() - t_epoch
        print(" - "
            f"{str(index+1).rjust(len(str(total)))}/{total}: "
            f"{str(identifier).ljust(identifier_max_len)} @ {td}: {state}"
        )
    print(f"precompute_data(computer={computer.__module__}.{computer.__qualname__}, identifiers=..., force={force}, page={page})")
    t_begin = datetime.now()
    failed = []
    # pagination
    page_size = total // n_pages + bool(total % n_pages)
    jobs = list(zip(identifiers, output_paths))[page_size*page : page_size*(page+1)]
    for index, (identifier, output_path) in enumerate(jobs, start=page_size*page):
        if not force and output_path.exists() and output_path.stat().st_size > 0:
            continue # already computed
        log("compute", is_start=True)
        # compute
        try:
            res = computer(identifier)
        except Exception as e:
            failed.append(identifier)
            log(f"failed compute: {e.__class__.__name__}: {e}")
            if DEBUG or debug: raise e
            continue
        if res is None:
            failed.append(identifier)
            log("no result")
            continue
        # write to file
        try:
            output_path.parent.mkdir(parents=True, exist_ok=True)
            res.to_h5_file(output_path)
        except Exception as e:
            failed.append(identifier)
            log(f"failed write: {e.__class__.__name__}: {e}")
            if output_path.is_file(): output_path.unlink() # cleanup
            if DEBUG or debug: raise e
            continue
        log("done")
    print("precompute_data finished in", datetime.now() - t_begin)
    print("failed:", failed or None)

768
ifield/data/common/scan.py Normal file
View File

@@ -0,0 +1,768 @@
from ...utils.helpers import compose
from . import points
from .h5_dataclasses import H5Dataclass, H5Array, H5ArrayNoSlice, TransformableDataclassMixin
from methodtools import lru_cache
from sklearn.neighbors import BallTree
import faiss
from trimesh import Trimesh
from typing import Iterable
from typing import Optional, TypeVar
import mesh_to_sdf
import mesh_to_sdf.scan as sdf_scan
import numpy as np
import trimesh
import trimesh.transformations as T
import warnings
__doc__ = """
Here are some helper types for data.
"""
_T = TypeVar("T")
class InvalidateLRUOnWriteMixin:
    """
    Mixin that clears every cached (methodtools lru_cache) attribute whenever
    a regular attribute is assigned, keeping derived values in sync.
    """
    def __setattr__(self, key, value):
        # methodtools stores its per-instance caches under a "__wire|" prefix
        if not key.startswith("__wire|"):
            caches = [attr for attr in dir(self) if attr.startswith("__wire|")]
            for attr in caches:
                getattr(self, attr).cache_clear()
        return super().__setattr__(key, value)
def lru_property(func):
    """A read-only property whose value is computed once and cached."""
    prop = property(func)
    return lru_cache(maxsize=1)(prop)
class SingleViewScan(H5Dataclass, TransformableDataclassMixin, InvalidateLRUOnWriteMixin, require_all=True):
    """
    A single-view range scan stored as separate hit / miss point sets.
    N = number of rays hitting the surface, M = number of missing rays.
    """
    points_hit     : H5ArrayNoSlice           # (N, 3)
    normals_hit    : Optional[H5ArrayNoSlice] # (N, 3)
    points_miss    : H5ArrayNoSlice           # (M, 3)
    distances_miss : Optional[H5ArrayNoSlice] # (M)
    colors_hit     : Optional[H5ArrayNoSlice] # (N, 3)
    colors_miss    : Optional[H5ArrayNoSlice] # (M, 3)
    uv_hits        : Optional[H5ArrayNoSlice] # (H, W) dtype=bool
    uv_miss        : Optional[H5ArrayNoSlice] # (H, W) dtype=bool (the reason we store both is due to missing depth sensor data or filtered backfaces)
    cam_pos        : H5ArrayNoSlice           # (3)
    cam_mat4       : Optional[H5ArrayNoSlice] # (4, 4)
    proj_mat4      : Optional[H5ArrayNoSlice] # (4, 4)
    transforms     : dict[str, H5ArrayNoSlice] # a map of 4x4 transformation matrices

    def transform(self: _T, mat4: np.ndarray, inplace=False) -> _T:
        """Apply a 4x4 homogeneous transform (uniform scale only) to all geometry."""
        # NOTE(review): column sums only equal the scale for scale-only matrices;
        # for rotation+scale the column *norms* would be needed — confirm usage
        scale_xyz = mat4[:3, :3].sum(axis=0) # https://math.stackexchange.com/a/1463487
        # bug fix: compare absolute differences; negative (shrunken-axis)
        # differences previously passed the check silently
        assert all(abs(scale_xyz - scale_xyz[0]) < 1e-8), f"differenty scaled axes: {scale_xyz}"
        out = self if inplace else self.copy(deep=False)
        out.points_hit  = T.transform_points(self.points_hit, mat4)
        # NOTE(review): transform_points also translates; normals would normally be
        # rotated only — confirm mat4 is rotation(+uniform scale) when normals exist
        out.normals_hit = T.transform_points(self.normals_hit, mat4) if self.normals_hit is not None else None
        out.points_miss = T.transform_points(self.points_miss, mat4)
        # bug fix: distances are per-ray scalars; scale by the (uniform) scale factor
        # instead of broadcasting against the 3-vector, and keep None when absent
        out.distances_miss = self.distances_miss * scale_xyz[0] if self.distances_miss is not None else None
        out.cam_pos  = T.transform_points(self.points_cam, mat4)[-1]
        out.cam_mat4  = (mat4 @ self.cam_mat4)  if self.cam_mat4  is not None else None
        out.proj_mat4 = (mat4 @ self.proj_mat4) if self.proj_mat4 is not None else None
        return out

    def compute_miss_distances(self: _T, *, copy: bool = False, deep: bool = False) -> _T:
        """Fill distances_miss with each missing ray's distance to the hit cloud."""
        assert not self.has_miss_distances
        if not self.is_hitting:
            raise ValueError("No hits to compute the ray distance towards")
        out = self.copy(deep=deep) if copy else self
        out.distances_miss \
            = distance_from_rays_to_point_cloud(
                ray_origins = out.points_cam,
                ray_dirs    = out.ray_dirs_miss,
                points      = out.points_hit,
            ).astype(out.points_cam.dtype)
        return out

    @lru_property
    def points(self) -> np.ndarray: # (N+M+1, 3)
        # all points: hits, misses and the camera position
        return np.concatenate((
            self.points_hit,
            self.points_miss,
            self.points_cam,
        ))

    @lru_property
    def uv_points(self) -> np.ndarray: # (N+M+1, 3)
        # points scattered back onto the (H, W) UV grid, NaN where no data
        if not self.has_uv: raise ValueError
        out = np.full((*self.uv_hits.shape, 3), np.nan, dtype=self.points_hit.dtype)
        out[self.uv_hits, :] = self.points_hit
        out[self.uv_miss, :] = self.points_miss
        return out

    @lru_property
    def uv_normals(self) -> np.ndarray: # (N+M+1, 3)
        # hit normals scattered back onto the (H, W) UV grid, NaN where no data
        if not self.has_uv: raise ValueError
        out = np.full((*self.uv_hits.shape, 3), np.nan, dtype=self.normals_hit.dtype)
        out[self.uv_hits, :] = self.normals_hit
        return out

    @lru_property
    def points_cam(self) -> Optional[np.ndarray]: # (1, 3)
        if self.cam_pos is None: return None
        return self.cam_pos[None, :]

    @lru_property
    def points_hit_centroid(self) -> np.ndarray:
        return self.points_hit.mean(axis=0)

    @lru_property
    def points_hit_std(self) -> np.ndarray:
        return self.points_hit.std(axis=0)

    @lru_property
    def is_hitting(self) -> bool:
        return len(self.points_hit) > 0

    @lru_property
    def is_empty(self) -> bool:
        return not (len(self.points_hit) or len(self.points_miss))

    @lru_property
    def has_colors(self) -> bool:
        return self.colors_hit is not None or self.colors_miss is not None

    @lru_property
    def has_normals(self) -> bool:
        return self.normals_hit is not None

    @lru_property
    def has_uv(self) -> bool:
        return self.uv_hits is not None

    @lru_property
    def has_miss_distances(self) -> bool:
        return self.distances_miss is not None

    @lru_property
    def xyzrgb_hit(self) -> np.ndarray: # (N, 6)
        if self.colors_hit is None: raise ValueError
        return np.concatenate([self.points_hit, self.colors_hit], axis=1)

    @lru_property
    def xyzrgb_miss(self) -> np.ndarray: # (M, 6)
        if self.colors_miss is None: raise ValueError
        return np.concatenate([self.points_miss, self.colors_miss], axis=1)

    @lru_property
    def ray_dirs_hit(self) -> np.ndarray: # (N, 3)
        out = self.points_hit - self.points_cam
        out /= np.linalg.norm(out, axis=-1)[:, None] # normalize
        return out

    @lru_property
    def ray_dirs_miss(self) -> np.ndarray: # (N, 3)
        out = self.points_miss - self.points_cam
        out /= np.linalg.norm(out, axis=-1)[:, None] # normalize
        return out

    @classmethod
    def from_mesh_single_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewScan":
        """Sample a synthetic single-view scan of `mesh`, from a random viewpoint unless given."""
        if "phi" not in kw and "theta" not in kw:
            kw["theta"], kw["phi"] = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
        scan = sample_single_view_scan_from_mesh(mesh, **kw)
        if compute_miss_distances and scan.is_hitting:
            scan.compute_miss_distances()
        return scan

    def to_uv_scan(self) -> "SingleViewUVScan":
        return SingleViewUVScan.from_scan(self)

    @classmethod
    def from_uv_scan(cls, uvscan: "SingleViewUVScan") -> "SingleViewScan":
        # bug fix: return annotation previously claimed SingleViewUVScan
        return uvscan.to_scan()
# The same, but with support for pagination (should have been this way since the start...)
class SingleViewUVScan(H5Dataclass, TransformableDataclassMixin, InvalidateLRUOnWriteMixin, require_all=True):
# B may be (N) or (H, W), the latter may be flattened
hits : H5Array # (*B) dtype=bool
miss : H5Array # (*B) dtype=bool (the reason we store both is due to missing data depth sensor data or filtered backface hits)
points : H5Array # (*B, 3) on far plane if miss, NaN if neither hit or miss
normals : Optional[H5Array] # (*B, 3) NaN if not hit
colors : Optional[H5Array] # (*B, 3)
distances : Optional[H5Array] # (*B) NaN if not miss
cam_pos : Optional[H5ArrayNoSlice] # (3) or (*B, 3)
cam_mat4 : Optional[H5ArrayNoSlice] # (4, 4)
proj_mat4 : Optional[H5ArrayNoSlice] # (4, 4)
transforms : dict[str, H5ArrayNoSlice] # a map of 4x4 transformation matrices
@classmethod
def from_scan(cls, scan: SingleViewScan):
    """Build a UV-grid-indexed scan from a SingleViewScan carrying UV data."""
    if not scan.has_uv:
        raise ValueError("Scan cloud has no UV data")
    hits = scan.uv_hits
    miss = scan.uv_miss
    dtype = scan.points_hit.dtype
    assert hits.ndim in (1, 2), hits.ndim
    assert hits.shape == miss.shape, (hits.shape, miss.shape)
    def nan_grid(*channels):
        # NaN-filled buffer shaped like the UV grid (plus trailing channels)
        return np.full((*hits.shape, *channels), np.nan, dtype=dtype)
    points = nan_grid(3)
    points[hits, :] = scan.points_hit
    points[miss, :] = scan.points_miss
    normals = None
    if scan.has_normals:
        normals = nan_grid(3)
        normals[hits, :] = scan.normals_hit
    distances = None
    if scan.has_miss_distances:
        distances = nan_grid()
        distances[miss] = scan.distances_miss
    colors = None
    if scan.has_colors:
        colors = nan_grid(3)
        if scan.colors_hit is not None:
            colors[hits, :] = scan.colors_hit
        if scan.colors_miss is not None:
            colors[miss, :] = scan.colors_miss
    return cls(
        hits       = hits,
        miss       = miss,
        points     = points,
        normals    = normals,
        colors     = colors,
        distances  = distances,
        cam_pos    = scan.cam_pos,
        cam_mat4   = scan.cam_mat4,
        proj_mat4  = scan.proj_mat4,
        transforms = scan.transforms,
    )
def to_scan(self) -> "SingleViewScan":
if not self.is_single_view: raise ValueError
return SingleViewScan(
points_hit = self.points [self.hits, :],
points_miss = self.points [self.miss, :],
normals_hit = self.normals [self.hits, :] if self.has_normals else None,
distances_miss = self.distances[self.miss] if self.has_miss_distances else None,
colors_hit = self.colors [self.hits, :] if self.has_colors else None,
colors_miss = self.colors [self.miss, :] if self.has_colors else None,
uv_hits = self.hits,
uv_miss = self.miss,
cam_pos = self.cam_pos,
cam_mat4 = self.cam_mat4,
proj_mat4 = self.proj_mat4,
transforms = self.transforms,
)
def to_mesh(self) -> trimesh.Trimesh:
faces: list[(tuple[int, int],)*3] = []
for x in range(self.hits.shape[0]-1):
for y in range(self.hits.shape[1]-1):
c11 = x, y
c12 = x, y+1
c22 = x+1, y+1
c21 = x+1, y
n = sum(map(self.hits.__getitem__, (c11, c12, c22, c21)))
if n == 3:
faces.append((*filter(self.hits.__getitem__, (c11, c12, c22, c21)),))
elif n == 4:
faces.append((c11, c12, c22))
faces.append((c11, c22, c21))
xy2idx = {c:i for i, c in enumerate(set(k for j in faces for k in j))}
assert self.colors is not None
return trimesh.Trimesh(
vertices = [self.points[i] for i in xy2idx.keys()],
vertex_colors = [self.colors[i] for i in xy2idx.keys()] if self.colors is not None else None,
faces = [tuple(xy2idx[i] for i in face) for face in faces],
)
def transform(self: _T, mat4: np.ndarray, inplace=False) -> _T:
scale_xyz = mat4[:3, :3].sum(axis=0) # https://math.stackexchange.com/a/1463487
assert all(scale_xyz - scale_xyz[0] < 1e-8), f"differenty scaled axes: {scale_xyz}"
unflat = self.hits.shape
flat = np.product(unflat)
out = self if inplace else self.copy(deep=False)
out.points = T.transform_points(self.points .reshape((*flat, 3)), mat4).reshape((*unflat, 3))
out.normals = T.transform_points(self.normals.reshape((*flat, 3)), mat4).reshape((*unflat, 3)) if self.normals_hit is not None else None
out.distances = self.distances_miss * scale_xyz
out.cam_pos = T.transform_points(self.cam_pos[None, ...], mat4)[0]
out.cam_mat4 = (mat4 @ self.cam_mat4) if self.cam_mat4 is not None else None
out.proj_mat4 = (mat4 @ self.proj_mat4) if self.proj_mat4 is not None else None
return out
def compute_miss_distances(self: _T, *, copy: bool = False, deep: bool = False, surface_points: Optional[np.ndarray] = None) -> _T:
assert not self.has_miss_distances
shape = self.hits.shape
out = self.copy(deep=deep) if copy else self
out.distances = np.zeros(shape, dtype=self.points.dtype)
if self.is_hitting:
out.distances[self.miss] \
= distance_from_rays_to_point_cloud(
ray_origins = self.cam_pos_unsqueezed_miss,
ray_dirs = self.ray_dirs_miss,
points = surface_points if surface_points is not None else self.points[self.hits],
)
return out
def fill_missing_points(self: _T, *, copy: bool = False, deep: bool = False) -> _T:
"""
Fill in missing points as hitting the far plane.
"""
if not self.is_2d:
raise ValueError("Cannot fill missing points for non-2d scan!")
if not self.is_single_view:
raise ValueError("Cannot fill missing points for non-single-view scans!")
if self.cam_mat4 is None:
raise ValueError("cam_mat4 is None")
if self.proj_mat4 is None:
raise ValueError("proj_mat4 is None")
uv = np.argwhere(self.missing).astype(self.points.dtype)
uv[:, 0] /= (self.missing.shape[1] - 1) / 2
uv[:, 1] /= (self.missing.shape[0] - 1) / 2
uv -= 1
uv = np.stack((
uv[:, 1],
-uv[:, 0],
np.ones(uv.shape[0]), # far clipping plane
np.ones(uv.shape[0]), # homogeneous coordinate
), axis=-1)
uv = uv @ (self.cam_mat4 @ np.linalg.inv(self.proj_mat4)).T
out = self.copy(deep=deep) if copy else self
out.points[self.missing, :] = uv[:, :3] / uv[:, 3][:, None]
return out
@lru_property
def is_hitting(self) -> bool:
return np.any(self.hits)
@lru_property
def has_colors(self) -> bool:
return not self.colors is None
@lru_property
def has_normals(self) -> bool:
return not self.normals is None
@lru_property
def has_miss_distances(self) -> bool:
return not self.distances is None
@lru_property
def any_missing(self) -> bool:
return np.any(self.missing)
@lru_property
def has_missing(self) -> bool:
return self.any_missing and not np.any(np.isnan(self.points[self.missing]))
@lru_property
def cam_pos_unsqueezed(self) -> H5Array:
if self.cam_pos.ndim != 1:
return self.cam_pos
else:
cam_pos = self.cam_pos
for _ in range(self.hits.ndim):
cam_pos = cam_pos[None, ...]
return cam_pos
@lru_property
def cam_pos_unsqueezed_hit(self) -> H5Array:
if self.cam_pos.ndim != 1:
return self.cam_pos[self.hits, :]
else:
return self.cam_pos[None, :]
@lru_property
def cam_pos_unsqueezed_miss(self) -> H5Array:
if self.cam_pos.ndim != 1:
return self.cam_pos[self.miss, :]
else:
return self.cam_pos[None, :]
@lru_property
def ray_dirs(self) -> H5Array:
return (self.points - self.cam_pos_unsqueezed) * (1 / self.depths[..., None])
@lru_property
def ray_dirs_hit(self) -> H5Array:
out = self.points[self.hits, :] - self.cam_pos_unsqueezed_hit
out /= np.linalg.norm(out, axis=-1)[..., None] # normalize
return out
@lru_property
def ray_dirs_miss(self) -> H5Array:
out = self.points[self.miss, :] - self.cam_pos_unsqueezed_miss
out /= np.linalg.norm(out, axis=-1)[..., None] # normalize
return out
@lru_property
def depths(self) -> H5Array:
return np.linalg.norm(self.points - self.cam_pos_unsqueezed, axis=-1)
@lru_property
def missing(self) -> H5Array:
return ~(self.hits | self.miss)
@classmethod
def from_mesh_single_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewUVScan":
if "phi" not in kw and not "theta" in kw:
kw["theta"], kw["phi"] = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
scan = sample_single_view_scan_from_mesh(mesh, **kw).to_uv_scan()
if compute_miss_distances:
scan.compute_miss_distances()
assert scan.is_2d
return scan
@classmethod
def from_mesh_sphere_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewUVScan":
scan = sample_sphere_view_scan_from_mesh(mesh, **kw)
if compute_miss_distances:
surface_points = None
if scan.hits.sum() > mesh.vertices.shape[0]:
surface_points = mesh.vertices.astype(scan.points.dtype)
if not kw.get("no_unit_sphere", False):
translation, scale = compute_unit_sphere_transform(mesh, dtype=scan.points.dtype)
surface_points = (surface_points + translation) * scale
scan.compute_miss_distances(surface_points=surface_points)
assert scan.is_flat
return scan
def flatten_and_permute_(self: _T, copy=False) -> _T: # inplace by default
n_items = np.product(self.hits.shape)
permutation = np.random.permutation(n_items)
out = self.copy(deep=False) if copy else self
out.hits = out.hits .reshape((n_items, ))[permutation]
out.miss = out.miss .reshape((n_items, ))[permutation]
out.points = out.points .reshape((n_items, 3))[permutation, :]
out.normals = out.normals .reshape((n_items, 3))[permutation, :] if out.has_normals else None
out.colors = out.colors .reshape((n_items, 3))[permutation, :] if out.has_colors else None
out.distances = out.distances.reshape((n_items, ))[permutation] if out.has_miss_distances else None
return out
@property
def is_single_view(self) -> bool:
return np.product(self.cam_pos.shape[:-1]) == 1 if not self.cam_pos is None else True
@property
def is_flat(self) -> bool:
return len(self.hits.shape) == 1
@property
def is_2d(self) -> bool:
return len(self.hits.shape) == 2
# transforms can be found in pytorch3d.transforms and in open3d
# and in trimesh.transformations
def sample_single_view_scans_from_mesh(
    mesh : Trimesh,
    *,
    n_batches : int,
    scan_resolution : int = 400,
    compute_normals : bool = False,
    fov : float = 1.0472, # 60 degrees in radians, vertical field of view.
    camera_distance : float = 2,
    no_filter_backhits : bool = False,
    ) -> Iterable[SingleViewScan]:
    """
    Yield `n_batches` synthetic scans of `mesh` from random camera poses
    on the unit sphere. The unit-sphere-normalized mesh is cached across
    iterations via `_mesh_cache`, so the rescale happens only once.
    """
    normalized_mesh_cache = []
    for _ in range(n_batches):
        theta, phi = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
        # (the previous `_mesh_is_normalized=False` kwarg was dropped: the
        #  callee has no such parameter, so every call raised TypeError)
        yield sample_single_view_scan_from_mesh(
            mesh = mesh,
            phi = phi,
            theta = theta,
            scan_resolution = scan_resolution,
            compute_normals = compute_normals,
            fov = fov,
            camera_distance = camera_distance,
            no_filter_backhits = no_filter_backhits,
            _mesh_cache = normalized_mesh_cache,
        )
def sample_single_view_scan_from_mesh(
    mesh : Trimesh,
    *,
    phi : float,
    theta : float,
    scan_resolution : int = 200,
    compute_normals : bool = False,
    fov : float = 1.0472, # 60 degrees in radians, vertical field of view.
    camera_distance : float = 2,
    no_filter_backhits : bool = False,
    no_unit_sphere : bool = False,
    dtype : type = np.float32,
    _mesh_cache : Optional[list] = None, # provide a list if mesh is reused
    ) -> SingleViewScan:
    """
    Render one synthetic depth scan of `mesh` from a camera at spherical
    coordinates (phi, theta), `camera_distance` away, looking at the origin.

    The mesh is rescaled/centered to the unit sphere first; when
    `no_unit_sphere` is set, the result is transformed back to model space.
    Pass the same list as `_mesh_cache` across calls with the same `mesh`
    object to skip renormalizing it each time.
    """
    # scale and center to unit sphere
    is_cache = isinstance(_mesh_cache, list)
    if is_cache and _mesh_cache and _mesh_cache[0] is mesh:
        # cache layout: [original_mesh, normalized_mesh, translation, scale]
        _, mesh, translation, scale = _mesh_cache
    else:
        if is_cache:
            if _mesh_cache:
                _mesh_cache.clear()
            _mesh_cache.append(mesh)
        translation, scale = compute_unit_sphere_transform(mesh)
        mesh = mesh_to_sdf.scale_to_unit_sphere(mesh)
        if is_cache:
            _mesh_cache.extend((mesh, translation, scale))
    z_near = 1
    z_far = 3
    cam_mat4 = sdf_scan.get_camera_transform_looking_at_origin(phi, theta, camera_distance=camera_distance)
    cam_pos = cam_mat4 @ np.array([0, 0, 0, 1]) # camera origin in world space (homogeneous)
    scan = sdf_scan.Scan(mesh,
        camera_transform = cam_mat4,
        resolution = scan_resolution,
        calculate_normals = compute_normals,
        fov = fov,
        z_near = z_near,
        z_far = z_far,
        no_flip_backfaced_normals = True
    )
    # all the scan rays that hit the far plane, based on sdf_scan.Scan.__init__
    misses = np.argwhere(scan.depth_buffer == 0)
    # pixel indices -> clip-space coordinates on the far plane
    points_miss = np.ones((misses.shape[0], 4))
    points_miss[:, [1, 0]] = misses.astype(float) / (scan_resolution -1) * 2 - 1
    points_miss[:, 1] *= -1
    points_miss[:, 2] = 1 # far plane in clipping space
    # clip space -> camera space -> world space, then perspective divide
    points_miss = points_miss @ (cam_mat4 @ np.linalg.inv(scan.projection_matrix)).T
    points_miss /= points_miss[:, 3][:, np.newaxis]
    points_miss = points_miss[:, :3]
    uv_hits = scan.depth_buffer != 0
    uv_miss = ~uv_hits
    if not no_filter_backhits:
        if not compute_normals:
            raise ValueError("not `no_filter_backhits` requires `compute_normals`")
        # inner product
        # keep only hits whose normal faces the camera (front faces)
        mask = np.einsum('ij,ij->i', scan.points - cam_pos[:3][None, :], scan.normals) < 0
        scan.points = scan.points [mask, :]
        scan.normals = scan.normals[mask, :]
        uv_hits[uv_hits] = mask # filtered backface hits become neither hit nor miss
    transforms = {}
    # undo unit-sphere transform
    if no_unit_sphere:
        # model space: m = u / scale - translation
        scan.points = scan.points * (1 / scale) - translation
        points_miss = points_miss * (1 / scale) - translation
        cam_pos[:3] = cam_pos[:3] * (1 / scale) - translation
        cam_mat4[:3, :] *= 1 / scale
        cam_mat4[:3, 3] -= translation
        transforms["unit_sphere"] = T.scale_and_translate(scale=scale, translate=translation)
        transforms["model"] = np.eye(4)
    else:
        transforms["model"] = np.linalg.inv(T.scale_and_translate(scale=scale, translate=translation))
        transforms["unit_sphere"] = np.eye(4)
    return SingleViewScan(
        normals_hit = scan.normals .astype(dtype),
        points_hit = scan.points .astype(dtype),
        points_miss = points_miss .astype(dtype),
        distances_miss = None,
        colors_hit = None,
        colors_miss = None,
        uv_hits = uv_hits .astype(bool),
        uv_miss = uv_miss .astype(bool),
        cam_pos = cam_pos[:3] .astype(dtype),
        cam_mat4 = cam_mat4 .astype(dtype),
        proj_mat4 = scan.projection_matrix .astype(dtype),
        transforms = {k:v.astype(dtype) for k, v in transforms.items()},
    )
def sample_sphere_view_scan_from_mesh(
    mesh : Trimesh,
    *,
    sphere_points : int = 4000, # resulting rays are n*(n-1)
    compute_normals : bool = False,
    no_filter_backhits : bool = False,
    no_unit_sphere : bool = False,
    no_permute : bool = False,
    dtype : type = np.float32,
    **kw,
    ) -> SingleViewUVScan:
    """
    Cast rays between all ordered pairs of ~`sphere_points` equidistant
    unit-sphere points through `mesh`, producing a flat `SingleViewUVScan`.
    Extra `kw` is forwarded to `generate_equidistant_sphere_rays`.
    """
    translation, scale = compute_unit_sphere_transform(mesh, dtype=dtype)
    # get unit-sphere points, then transform to model space
    two_sphere = generate_equidistant_sphere_rays(sphere_points, **kw).astype(dtype) # (n*(n-1), 2, 3)
    two_sphere = two_sphere / scale - translation # we transform after cache lookup
    if mesh.ray.__class__.__module__.split(".")[-1] != "ray_pyembree":
        warnings.warn("Pyembree not found, the ray-tracing will be SLOW!")
    (
        locations,
        index_ray,
        index_tri,
    ) = mesh.ray.intersects_location(
        two_sphere[:, 0, :],
        two_sphere[:, 1, :] - two_sphere[:, 0, :], # direction, not target coordinate
        multiple_hits=False,
    )
    if compute_normals:
        location_normals = mesh.face_normals[index_tri]
    batch = two_sphere.shape[:1]
    hits = np.zeros((*batch,), dtype=bool) # `np.bool` is removed in numpy>=1.24
    miss = np.ones((*batch,), dtype=bool)
    cam_pos = two_sphere[:, 0, :]
    intersections = two_sphere[:, 1, :] # far-plane, effectively
    normals = np.zeros((*batch, 3), dtype=dtype) if compute_normals else None
    index_ray_front = index_ray
    if not no_filter_backhits:
        if not compute_normals:
            raise ValueError("not `no_filter_backhits` requires `compute_normals`")
        # keep only hits whose normal faces the ray origin (front faces)
        mask = ((intersections[index_ray] - cam_pos[index_ray]) * location_normals).sum(axis=-1) <= 0
        index_ray_front = index_ray[mask]
    hits[index_ray_front] = True
    miss[index_ray] = False # filtered backface hits end up neither hit nor miss
    intersections[index_ray] = locations
    if compute_normals: # (previously crashed with NameError when compute_normals=False)
        normals[index_ray] = location_normals
    if not no_permute:
        assert len(batch) == 1, batch
        permutation = np.random.permutation(*batch)
        hits = hits [permutation]
        miss = miss [permutation]
        intersections = intersections[permutation, :]
        cam_pos = cam_pos [permutation, :]
        if normals is not None:
            normals = normals [permutation, :]
    # apply unit sphere transform
    if not no_unit_sphere:
        intersections = (intersections + translation) * scale
        cam_pos = (cam_pos + translation) * scale
    return SingleViewUVScan(
        hits = hits,
        miss = miss,
        points = intersections,
        normals = normals,
        colors = None, # colors
        distances = None,
        cam_pos = cam_pos,
        cam_mat4 = None,
        proj_mat4 = None,
        transforms = {},
    )
def distance_from_rays_to_point_cloud(
    ray_origins : np.ndarray, # (*A, 3)
    ray_dirs : np.ndarray, # (*A, 3)
    points : np.ndarray, # (*B, 3)
    dirs_normalized : bool = False,
    n_steps : int = 40,
    ) -> np.ndarray: # (A)
    """
    Approximate, for each ray, its distance to the nearest point in the
    cloud: sphere-trace the nearest-neighbor distance field for `n_steps`
    (damped by 0.25), then measure the perpendicular distance from the ray
    to the best candidate point found along the way.
    """
    # anything outside of this volume will never contribute to the result
    max_norm = max(
        np.linalg.norm(ray_origins, axis=-1).max(),
        np.linalg.norm(points, axis=-1).max(),
    ) * 1.02
    if not dirs_normalized:
        ray_dirs = ray_dirs / np.linalg.norm(ray_dirs, axis=-1)[..., None]
    # deal with single-view clouds
    if ray_origins.shape != ray_dirs.shape:
        ray_origins = np.broadcast_to(ray_origins, ray_dirs.shape)
    n_points = np.prod(points.shape[:-1]) # np.product is removed in numpy>=2.0
    use_faiss = n_points > 160000*4 # approximate index scales better for big clouds
    if not use_faiss:
        index = BallTree(points)
    else:
        # http://ann-benchmarks.com/index.html
        assert np.issubdtype(points.dtype, np.float32)
        assert np.issubdtype(ray_origins.dtype, np.float32)
        assert np.issubdtype(ray_dirs.dtype, np.float32)
        index = faiss.index_factory(points.shape[-1], "NSG32,Flat") # https://github.com/facebookresearch/faiss/wiki/The-index-factory
        index.nprobe = 5 # 10 # default is 1
        index.train(points)
        index.add(points)
    # distance and index of the nearest point to each ray origin
    if not use_faiss:
        min_d, min_n = index.query(ray_origins, k=1, return_distance=True)
    else:
        min_d, min_n = index.search(ray_origins, k=1) # returns squared L2
        min_d = np.sqrt(min_d)
    acc_d = min_d.copy() # accumulated distance marched along each ray
    for step in range(1, n_steps+1):
        query_points = ray_origins + acc_d * ray_dirs
        if max_norm is not None: # max_norm is always set above; kept for clarity
            # only keep marching rays still inside the bounding volume
            qmask = np.linalg.norm(query_points, axis=-1) < max_norm
            if not qmask.any(): break
            query_points = query_points[qmask]
        else:
            qmask = slice(None)
        if not use_faiss:
            current_d, current_n = index.query(query_points, k=1, return_distance=True)
        else:
            current_d, current_n = index.search(query_points, k=1)
            current_d = np.sqrt(current_d)
        if max_norm is not None:
            # track the minimum distance seen, and which point produced it
            min_d[qmask] = np.minimum(current_d, min_d[qmask])
            new_min_mask = min_d[qmask] == current_d
            qmask2 = qmask.copy()
            qmask2[qmask2] = new_min_mask[..., 0]
            min_n[qmask2] = current_n[new_min_mask[..., 0]]
            acc_d[qmask] += current_d * 0.25 # damped step to avoid overshooting
        else:
            np.minimum(current_d, min_d, out=min_d)
            new_min_mask = min_d == current_d
            min_n[new_min_mask] = current_n[new_min_mask]
            acc_d += current_d * 0.25
    # perpendicular (point-to-line) distance to the best candidate point
    closest_points = points[min_n[:, 0], :] # k=1
    distances = np.linalg.norm(np.cross(closest_points - ray_origins, ray_dirs, axis=-1), axis=-1)
    return distances
# helpers
@compose(np.array) # make copy to avoid lru cache mutation
@lru_cache(maxsize=1)
def generate_equidistant_sphere_rays(n : int, **kw) -> np.ndarray: # output (n*(n-1)) rays, n may be off
    """All ordered pairs of ~n equidistant sphere points, as a (N*(N-1), 2, 3) ray array."""
    sphere_points = points.generate_equidistant_sphere_points(n=n, **kw)
    idx = np.arange(len(sphere_points)) # (N)
    # cartesian product of indices; the first column cycles fastest
    first = np.tile(idx, idx.size) # (N**2)
    second = np.repeat(idx, idx.size) # (N**2)
    # a ray needs two distinct endpoints: drop the diagonal
    keep = first != second
    pairs = np.stack((first[keep], second[keep]), axis=1) # (N*(N-1), 2)
    # lookup sphere points
    return sphere_points[pairs, :] # (N*(N-1), 2, 3)
def compute_unit_sphere_transform(mesh: "Trimesh", *, dtype=None) -> tuple[np.ndarray, float]:
    """
    returns translation and scale which mesh_to_sdf applies to meshes
    before computing their SDF cloud, i.e. `unit = (vertex + translation) * scale`

    dtype: optionally cast the results (e.g. np.float32). Defaults to None
        (previously the builtin `type`, which made `astype` fail on calls
        that relied on the default).
    """
    # the transformation applied by mesh_to_sdf.scale_to_unit_sphere(mesh)
    translation = -mesh.bounding_box.centroid
    scale = 1 / np.max(np.linalg.norm(mesh.vertices + translation, axis=1))
    if dtype is not None:
        translation = translation.astype(dtype)
        scale       = scale      .astype(dtype)
    return translation, scale

View File

@@ -0,0 +1,6 @@
__doc__ = """
Some helper types.
"""
class MalformedMesh(Exception):
    """Raised when a mesh file cannot be read or parsed into a usable mesh."""

28
ifield/data/config.py Normal file
View File

@@ -0,0 +1,28 @@
from ..utils.helpers import make_relative
from pathlib import Path
from typing import Optional
import os
import warnings
def data_path_get(dataset_name: str, no_warn: bool = False) -> Path:
    """
    Resolve the on-disk directory for `dataset_name`, honoring env overrides:
    $IFIELD_DATA_MODELS_<NAME> first, then $IFIELD_DATA_MODELS/<name>,
    falling back to <repo>/data/models/<name>. Warns if the directory is
    missing, unless `no_warn`.
    """
    specific_envvar = f"IFIELD_DATA_MODELS_{dataset_name.replace(*'-_').upper()}"
    if specific_envvar in os.environ:
        data_path = Path(os.environ[specific_envvar])
    elif "IFIELD_DATA_MODELS" in os.environ:
        data_path = Path(os.environ["IFIELD_DATA_MODELS"]) / dataset_name
    else:
        repo_root = Path(__file__).resolve().parent.parent.parent
        data_path = repo_root / "data" / "models" / dataset_name
    if not no_warn and not data_path.is_dir():
        warnings.warn(f"{make_relative(data_path, Path.cwd()).__str__()!r} is not a directory!")
    return data_path
def data_path_persist(dataset_name: Optional[str], path: os.PathLike) -> os.PathLike:
    "Persist the datapath, ensuring subprocesses also will use it. The path passes through."
    # None targets the shared fallback var; a name targets its specific override
    if dataset_name is None:
        envvar = "IFIELD_DATA_MODELS"
    else:
        envvar = f"IFIELD_DATA_MODELS_{dataset_name.replace(*'-_').upper()}"
    os.environ[envvar] = str(path)
    return path

View File

@@ -0,0 +1,56 @@
from ..config import data_path_get, data_path_persist
from collections import namedtuple
import os
# Data source:
# http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/ssd.htm
# was `__ALL__`, which Python ignores; it also listed `Model`/`MODELS`,
# which this module never defines
__all__ = ["config", "Archive"]
Archive = namedtuple("Archive", "url fname download_size_str") # one downloadable zip
@(lambda x: x()) # singleton
class config:
    """Singleton holding dataset paths and the COSEG download catalogue."""
    # instance property so assignment via `config.DATA_PATH = ...` persists to env
    DATA_PATH = property(
        doc = """
            Path to the dataset. The following envvars override it:
                ${IFIELD_DATA_MODELS}/coseg
                ${IFIELD_DATA_MODELS_COSEG}
        """,
        fget = lambda self: data_path_get ("coseg"),
        fset = lambda self, path: data_path_persist("coseg", path),
    )
    @property
    def IS_DOWNLOADED_DB(self) -> list[os.PathLike]:
        # json files recording which urls have been fetched already
        return [
            self.DATA_PATH / "downloaded.json",
        ]
    # object-set name -> downloadable mesh archive
    SHAPES: dict[str, Archive] = {
        "candelabra" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Candelabra/shapes.zip", "candelabra-shapes.zip", "3,3M"),
        "chair" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Chair/shapes.zip", "chair-shapes.zip", "3,2M"),
        "four-legged" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Four-legged/shapes.zip", "four-legged-shapes.zip", "2,9M"),
        "goblets" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Goblets/shapes.zip", "goblets-shapes.zip", "500K"),
        "guitars" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Guitars/shapes.zip", "guitars-shapes.zip", "1,9M"),
        "lampes" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Lampes/shapes.zip", "lampes-shapes.zip", "2,4M"),
        "vases" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Vases/shapes.zip", "vases-shapes.zip", "5,5M"),
        "irons" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Irons/shapes.zip", "irons-shapes.zip", "1,2M"),
        "tele-aliens" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Tele-aliens/shapes.zip", "tele-aliens-shapes.zip", "15M"),
        "large-vases" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Large-Vases/shapes.zip", "large-vases-shapes.zip", "6,2M"),
        "large-chairs": Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Large-Chairs/shapes.zip", "large-chairs-shapes.zip", "14M"),
    }
    # object-set name -> downloadable ground-truth segmentation archive
    GROUND_TRUTHS: dict[str, Archive] = {
        "candelabra" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Candelabra/gt.zip", "candelabra-gt.zip", "68K"),
        "chair" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Chair/gt.zip", "chair-gt.zip", "20K"),
        "four-legged" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Four-legged/gt.zip", "four-legged-gt.zip", "24K"),
        "goblets" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Goblets/gt.zip", "goblets-gt.zip", "4,0K"),
        "guitars" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Guitars/gt.zip", "guitars-gt.zip", "12K"),
        "lampes" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Lampes/gt.zip", "lampes-gt.zip", "60K"),
        "vases" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Vases/gt.zip", "vases-gt.zip", "40K"),
        "irons" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Irons/gt.zip", "irons-gt.zip", "8,0K"),
        "tele-aliens" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Tele-aliens/gt.zip", "tele-aliens-gt.zip", "72K"),
        "large-vases" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Large-Vases/gt.zip", "large-vases-gt.zip", "68K"),
        "large-chairs": Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Large-Chairs/gt.zip", "large-chairs-gt.zip", "116K"),
    }

View File

@@ -0,0 +1,135 @@
#!/usr/bin/env python3
from . import config
from ...utils.helpers import make_relative
from ..common import download
from pathlib import Path
from textwrap import dedent
import argparse
import io
import zipfile
def is_downloaded(*args, **kwargs):
    """`download.is_downloaded` preconfigured with this dataset's download-db files."""
    return download.is_downloaded(*args, dbfiles=config.IS_DOWNLOADED_DB, **kwargs)
def download_and_extract(target_dir: Path, url_dict: dict[str, str], *, force=False, silent=False) -> bool:
    """
    Download and unpack each archive in `url_dict` (url -> filename) into
    `target_dir`, skipping urls already recorded as downloaded unless
    `force`. Returns True if anything was actually (re)downloaded.
    """
    target_dir.mkdir(parents=True, exist_ok=True)
    ret = False
    for url, fname in url_dict.items():
        if not force:
            if is_downloaded(target_dir, url): continue
        # cheap HTTP HEAD sanity check before fetching
        if not download.check_url(url):
            print("ERROR:", url)
            continue
        ret = True
        if force or not (target_dir / "archives" / fname).is_file():
            # NOTE(review): buffers the whole archive in memory before writing
            data = download.download_data(url, silent=silent, label=fname)
            assert url.endswith(".zip")
            print("writing...")
            (target_dir / "archives").mkdir(parents=True, exist_ok=True)
            with (target_dir / "archives" / fname).open("wb") as f:
                f.write(data)
            del data
        print(f"extracting {fname}...")
        # extract into e.g. "chair/" from "chair-shapes.zip" / "chair-gt.zip"
        with zipfile.ZipFile(target_dir / "archives" / fname, 'r') as f:
            f.extractall(target_dir / Path(fname).stem.removesuffix("-shapes").removesuffix("-gt"))
        is_downloaded(target_dir, url, add=True) # record success in the download db
    return ret
def make_parser() -> argparse.ArgumentParser:
    """Build the argparse CLI for the `download-coseg` entrypoint."""
    parser = argparse.ArgumentParser(description=dedent("""
        Download The COSEG Shape Dataset.
        More info: http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/ssd.htm
        Example:
            download-coseg --shapes chair
        """), formatter_class=argparse.RawTextHelpFormatter)
    # (example previously said "chairs", which is not a valid set name)
    arg = parser.add_argument
    arg("sets", nargs="*", default=[],
        help="Which set to download, defaults to none.")
    arg("--all", action="store_true",
        help="Download all sets")
    arg("--dir", default=str(config.DATA_PATH),
        help=f"The target directory. Default is {make_relative(config.DATA_PATH, Path.cwd()).__str__()!r}")
    arg("--shapes", action="store_true",
        help="Download the 3d shapes for each chosen set")
    arg("--gts", action="store_true",
        help="Download the ground-truth segmentation data for each chosen set")
    arg("--list", action="store_true",
        help="Lists all the sets")
    arg("--list-urls", action="store_true",
        help="Lists the urls to download")
    arg("--list-sizes", action="store_true",
        help="Lists the download size of each set")
    arg("--silent", action="store_true",
        help="Do not show download progress bars") # was an empty help string
    arg("--force", action="store_true",
        help="Download again even if already downloaded")
    return parser
# entrypoint
def cli(parser=make_parser()):
    """Entrypoint for `download-coseg`: resolve the selected sets, then download+extract."""
    args = parser.parse_args()
    assert set(config.SHAPES.keys()) == set(config.GROUND_TRUTHS.keys())
    set_names = sorted(set(args.sets))
    if args.all:
        assert not set_names, "--all is mutually exclusive from manually selected sets"
        set_names = sorted(config.SHAPES.keys())
    # the --list* flags print and exit without downloading anything
    if args.list:
        print(*config.SHAPES.keys(), sep="\n")
        exit()
    if args.list_sizes:
        print(*(f"{set_name:<15}{config.SHAPES[set_name].download_size_str}" for set_name in (set_names or config.SHAPES.keys())), sep="\n")
        exit()
    try:
        # url -> filename, merging shape and ground-truth archives as requested
        url_dict \
            = {config.SHAPES[set_name].url : config.SHAPES[set_name].fname for set_name in set_names if args.shapes} \
            | {config.GROUND_TRUTHS[set_name].url : config.GROUND_TRUTHS[set_name].fname for set_name in set_names if args.gts}
    except KeyError:
        print("Error: unrecognized object name:", *set(set_names).difference(config.SHAPES.keys()), sep="\n")
        exit(1)
    if not url_dict:
        if set_names and not (args.shapes or args.gts):
            print("Error: Provide at least one of --shapes of --gts")
        else:
            print("Error: No object set was selected for download!")
        exit(1)
    if args.list_urls:
        print(*url_dict.keys(), sep="\n")
        exit()
    print("Download start")
    any_downloaded = download_and_extract(
        target_dir = Path(args.dir),
        url_dict = url_dict,
        force = args.force,
        silent = args.silent,
    )
    if not any_downloaded:
        print("Everything has already been downloaded, skipping.")

View File

@@ -0,0 +1,137 @@
#!/usr/bin/env python3
import os; os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
from . import config, read
from ...utils.helpers import make_relative
from pathlib import Path
from textwrap import dedent
import argparse
def make_parser() -> argparse.ArgumentParser:
    """Build the argparse CLI for the COSEG preprocessing entrypoint."""
    parser = argparse.ArgumentParser(description=dedent("""
        Preprocess the COSEG dataset. Depends on `download-coseg --shapes ...` having been run.
        """), formatter_class=argparse.RawTextHelpFormatter)
    arg = parser.add_argument # brevity
    arg("items", nargs="*", default=[],
        help="Which object-set[/model-id] to process, defaults to all downloaded. Format: OBJECT-SET[/MODEL-ID]")
    arg("--dir", default=str(config.DATA_PATH),
        help=f"The target directory. Default is {make_relative(config.DATA_PATH, Path.cwd()).__str__()!r}")
    arg("--force", action="store_true",
        help="Overwrite existing files")
    arg("--list-models", action="store_true",
        help="List the downloaded models available for preprocessing")
    arg("--list-object-sets", action="store_true",
        help="List the downloaded object-sets available for preprocessing")
    arg("--list-pages", type=int, default=None,
        help="List the downloaded models available for preprocessing, paginated into N pages.")
    arg("--page", nargs=2, type=int, default=[0, 1],
        help="Subset of parts to compute. Use to parallelize. (page, total), page is 0 indexed")
    arg2 = parser.add_argument_group("preprocessing targets").add_argument # brevity
    arg2("--precompute-mesh-sv-scan-clouds", action="store_true",
        help="Compute single-view hit+miss point clouds from 100 synthetic scans.")
    arg2("--precompute-mesh-sv-scan-uvs", action="store_true",
        help="Compute single-view hit+miss UV clouds from 100 synthetic scans.")
    arg2("--precompute-mesh-sphere-scan", action="store_true",
        help="Compute a sphere-view hit+miss cloud cast from n to n unit sphere points.")
    arg3 = parser.add_argument_group("modifiers").add_argument # brevity
    arg3("--n-sphere-points", type=int, default=4000,
        help="The number of unit-sphere points to sample rays from. Final result: n*(n-1).")
    arg3("--compute-miss-distances", action="store_true",
        help="Compute the distance to the nearest hit for each miss in the hit+miss clouds.")
    arg3("--fill-missing-uv-points", action="store_true",
        help="Fill in missing (neither hit nor miss) UV points as hitting the far plane.") # was "TODO"
    arg3("--no-filter-backhits", action="store_true",
        help="Do not filter scan hits on backside of mesh faces.")
    arg3("--no-unit-sphere", action="store_true",
        help="Do not center the objects to the unit sphere.")
    arg3("--convert-ok", action="store_true",
        help="Allow reusing point clouds for uv clouds and vice versa. (does not account for other hparams)")
    arg3("--debug", action="store_true",
        help="Abort on failure.") # typo fix: was "failiure"
    return parser
# entrypoint
def cli(parser=make_parser()):
    """Entrypoint: resolve the selected models, then run the chosen preprocessing targets."""
    args = parser.parse_args()
    if not any(getattr(args, k) for k in dir(args) if k.startswith("precompute_")) and not (args.list_models or args.list_object_sets or args.list_pages):
        parser.error("no preprocessing target selected") # exits
    config.DATA_PATH = Path(args.dir)
    # items are either whole object sets ("chair") or single models ("chair/42")
    object_sets = [i for i in args.items if "/" not in i]
    models = [i.split("/") for i in args.items if "/" in i]
    # convert/expand synsets to models
    # they are mutually exclusive
    if object_sets: assert not models
    if models: assert not object_sets
    if not models:
        models = read.list_model_ids(tuple(object_sets) or None)
    # the --list* flags print and exit without preprocessing anything
    if args.list_models:
        try:
            print(*(f"{object_set_id}/{model_id}" for object_set_id, model_id in models), sep="\n")
        except BrokenPipeError:
            pass # e.g. piped into `head`
        parser.exit()
    if args.list_object_sets:
        try:
            print(*sorted(set(object_set_id for object_set_id, model_id in models)), sep="\n")
        except BrokenPipeError:
            pass
        parser.exit()
    if args.list_pages is not None:
        # emit one ready-made command line per (model, page) combination
        try:
            print(*(
                f"--page {i} {args.list_pages} {object_set_id}/{model_id}"
                for object_set_id, model_id in models
                for i in range(args.list_pages)
            ), sep="\n")
        except BrokenPipeError:
            pass
        parser.exit()
    if args.precompute_mesh_sv_scan_clouds:
        read.precompute_mesh_scan_point_clouds(
            models,
            compute_miss_distances = args.compute_miss_distances,
            no_filter_backhits = args.no_filter_backhits,
            no_unit_sphere = args.no_unit_sphere,
            convert_ok = args.convert_ok,
            page = args.page,
            force = args.force,
            debug = args.debug,
        )
    if args.precompute_mesh_sv_scan_uvs:
        read.precompute_mesh_scan_uvs(
            models,
            compute_miss_distances = args.compute_miss_distances,
            fill_missing_points = args.fill_missing_uv_points,
            no_filter_backhits = args.no_filter_backhits,
            no_unit_sphere = args.no_unit_sphere,
            convert_ok = args.convert_ok,
            page = args.page,
            force = args.force,
            debug = args.debug,
        )
    if args.precompute_mesh_sphere_scan:
        read.precompute_mesh_sphere_scan(
            models,
            sphere_points = args.n_sphere_points,
            compute_miss_distances = args.compute_miss_distances,
            no_filter_backhits = args.no_filter_backhits,
            no_unit_sphere = args.no_unit_sphere,
            page = args.page,
            force = args.force,
            debug = args.debug,
        )

290
ifield/data/coseg/read.py Normal file
View File

@@ -0,0 +1,290 @@
from . import config
from ..common import points
from ..common import processing
from ..common.scan import SingleViewScan, SingleViewUVScan
from ..common.types import MalformedMesh
from functools import lru_cache
from typing import Optional, Iterable
import numpy as np
import trimesh
import trimesh.transformations as T
__doc__ = """
Here are functions for reading and preprocessing coseg benchmark data
There are essentially a few sets per object:
"img" - meaning the RGBD images (none found in coseg)
"mesh_scans" - meaning synthetic scans of a mesh
"""
MESH_TRANSFORM_SKYWARD = T.rotation_matrix(np.pi/2, (1, 0, 0)) # rotate to be upright in pyrender
# Per-model pose fixups, keyed by (object_set, model_id); each value is a
# 4x4 homogeneous rotation matrix applied on top of the stored pose.
MESH_POSE_CORRECTIONS = { # to gain a shared canonical orientation
    ("four-legged", 381): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
    ("four-legged", 382): T.rotation_matrix( 1*np.pi/2, (0, 0, 1)),
    ("four-legged", 383): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
    ("four-legged", 384): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
    ("four-legged", 385): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
    ("four-legged", 386): T.rotation_matrix( 1*np.pi/2, (0, 0, 1)),
    ("four-legged", 387): T.rotation_matrix(-0.2*np.pi/2, (0, 1, 0))@T.rotation_matrix(1*np.pi/2, (0, 0, 1)),
    ("four-legged", 388): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
    ("four-legged", 389): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
    ("four-legged", 390): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
    ("four-legged", 391): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
    ("four-legged", 392): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
    ("four-legged", 393): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
    ("four-legged", 394): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
    ("four-legged", 395): T.rotation_matrix(-0.2*np.pi/2, (0, 1, 0))@T.rotation_matrix(1*np.pi/2, (0, 0, 1)),
    ("four-legged", 396): T.rotation_matrix( 1*np.pi/2, (0, 0, 1)),
    ("four-legged", 397): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
    ("four-legged", 398): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
    ("four-legged", 399): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
    ("four-legged", 400): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
}
# (object_set_id, model_id) — uniquely identifies a model within the dataset
ModelUid = tuple[str, int]
@lru_cache(maxsize=1)
def list_object_sets() -> list[str]:
    """Sorted names of all object sets: directories with a "shapes" subdir."""
    names = []
    for object_set in config.DATA_PATH.iterdir():
        if object_set.name == "archive": # the archive directory is not an object set
            continue
        if (object_set / "shapes").is_dir():
            names.append(object_set.name)
    return sorted(names)
@lru_cache(maxsize=1)
def list_model_ids(object_sets: Optional[tuple[str, ...]] = None) -> list[ModelUid]:
    """All (object_set_id, model_id) pairs, optionally filtered to ``object_sets``.

    ``object_sets`` must be a tuple (hashable) so the result can be lru_cached.
    BUGFIX: the annotation used to say ``tuple[str]``, which types a 1-tuple;
    any number of set names is accepted.
    """
    return sorted(
        (object_set.name, int(model.stem))
        for object_set in config.DATA_PATH.iterdir()
        if (object_set / "shapes").is_dir() and object_set.name != "archive" and (object_sets is None or object_set.name in object_sets)
        for model in (object_set / "shapes").iterdir()
        if model.is_file() and model.suffix == ".off"
    )
def list_model_id_strings(object_sets: Optional[tuple[str]] = None) -> list[str]:
    """Every model uid rendered as a single "set-id" string."""
    uids = list_model_ids(object_sets)
    return [model_uid_to_string(*uid) for uid in uids]
def model_uid_to_string(object_set_id: str, model_id: int) -> str:
    """Flatten a model uid, e.g. ("four-legged", 381) -> "four-legged-381"."""
    return "-".join((object_set_id, str(model_id)))
def model_id_string_to_uid(model_string_uid: str) -> ModelUid:
    """Inverse of model_uid_to_string: split on the LAST "-"."""
    head, sep, tail = model_string_uid.rpartition("-")
    assert sep == "-" # a well-formed uid string always contains a separator
    return (head, int(tail))
@lru_cache(maxsize=1)
def list_mesh_scan_sphere_coords(n_poses: int = 50) -> list[tuple[float, float]]: # (theta, phi)
    """(theta, phi) sphere coordinates of ``n_poses`` equidistant camera poses."""
    coords = points.generate_equidistant_sphere_points(n_poses, compute_sphere_coordinates=True)
    return coords
def mesh_scan_identifier(*, phi: float, theta: float) -> str:
    """Filesystem-safe token for one (theta, phi) camera pose, e.g. "n0d50p1d00"."""
    def encode(value: float) -> str:
        # sign prefix ("p"/"n"), then the magnitude with "." replaced by "d"
        prefix = "p" if value >= 0 else "n"
        return prefix + f"{abs(value):.2f}".replace(".", "d")
    return encode(theta) + encode(phi)
@lru_cache(maxsize=1)
def list_mesh_scan_identifiers(n_poses: int = 50) -> list[str]:
    """Pose identifiers, in the same order as list_mesh_scan_sphere_coords."""
    identifiers = []
    for theta, phi in list_mesh_scan_sphere_coords(n_poses):
        identifiers.append(mesh_scan_identifier(phi=phi, theta=theta))
    assert len(identifiers) == len(set(identifiers))
    return identifiers
# ===
def read_mesh(object_set_id: str, model_id: int) -> trimesh.Trimesh:
    """Load one ``.off`` mesh and rotate it into the shared canonical pose.

    Raises:
        FileNotFoundError: if the mesh file does not exist.
        MalformedMesh: if trimesh fails to parse the file.
    """
    path = config.DATA_PATH / object_set_id / "shapes" / f"{model_id}.off"
    if not path.is_file():
        raise FileNotFoundError(f"{path = }")
    try:
        # force="mesh": ask trimesh for a single Trimesh rather than a Scene
        mesh = trimesh.load(path, force="mesh")
    except Exception as e:
        raise MalformedMesh(f"Trimesh raised: {e.__class__.__name__}: {e}") from e
    # per-model correction (if any) is composed on top of the skyward rotation
    pose = MESH_POSE_CORRECTIONS.get((object_set_id, int(model_id)))
    mesh.apply_transform(pose @ MESH_TRANSFORM_SKYWARD if pose is not None else MESH_TRANSFORM_SKYWARD)
    return mesh
# === single-view scan clouds
def compute_mesh_scan_point_cloud(
        object_set_id          : str,
        model_id               : int,
        phi                    : float,
        theta                  : float,
        *,
        compute_miss_distances : bool = False,
        fill_missing_points    : bool = False,
        compute_normals        : bool = True,
        convert_ok             : bool = False,
        **kw,
        ) -> SingleViewScan:
    """Synthetically scan one mesh from one camera pose; return the hit+miss cloud.

    If ``convert_ok``, a previously stored UV scan is reused and converted
    instead of re-scanning (NOTE: the conversion does not re-apply the other
    keyword arguments).  Extra ``kw`` are forwarded to
    SingleViewScan.from_mesh_single_view.
    """
    if convert_ok:
        try:
            return read_mesh_scan_uv(object_set_id, model_id, phi=phi, theta=theta).to_scan()
        except FileNotFoundError:
            pass # nothing precomputed; fall through to a fresh scan
    mesh = read_mesh(object_set_id, model_id)
    scan = SingleViewScan.from_mesh_single_view(mesh,
        phi             = phi,
        theta           = theta,
        compute_normals = compute_normals,
        **kw,
    )
    if compute_miss_distances:
        scan.compute_miss_distances()
    if fill_missing_points:
        scan.fill_missing_points()
    return scan
def precompute_mesh_scan_point_clouds(models: Iterable[ModelUid], *, n_poses: int = 50, page: tuple[int, int] = (0, 1), force = False, debug = False, **kw):
    """Precompute all single-view scan clouds and store them as HDF5 datasets.

    Extra keyword arguments are forwarded to compute_mesh_scan_point_cloud.
    """
    # BUGFIX: models is iterated several times below (paths, the two max()
    # calls, and the comprehension); a generator argument would silently
    # come up empty after the first pass
    models = list(models)
    cam_poses        = list_mesh_scan_sphere_coords(n_poses=n_poses)
    pose_identifiers = list_mesh_scan_identifiers (n_poses=n_poses)
    assert len(cam_poses) == len(pose_identifiers)
    paths = list_mesh_scan_point_cloud_h5_fnames(models, pose_identifiers, n_poses=n_poses)
    mlen_syn = max(len(object_set_id) for object_set_id, model_id in models)
    mlen_mod = max(len(str(model_id)) for object_set_id, model_id in models)
    pretty_identifiers = [
        f"{object_set_id.ljust(mlen_syn)} @ {str(model_id).ljust(mlen_mod)} @ {i:>5} @ ({identifier}: {theta:.2f}, {phi:.2f})"
        for object_set_id, model_id in models
        for i, (identifier, (theta, phi)) in enumerate(zip(pose_identifiers, cam_poses))
    ]
    mesh_cache = [] # lets compute_mesh_scan_point_cloud reuse the mesh across poses
    def computer(pretty_identifier: str) -> SingleViewScan:
        # recover the model uid and pose index from the pretty identifier
        object_set_id, model_id, index, _ = map(str.strip, pretty_identifier.split("@"))
        theta, phi = cam_poses[int(index)]
        return compute_mesh_scan_point_cloud(object_set_id, int(model_id), phi=phi, theta=theta, _mesh_cache=mesh_cache, **kw)
    return processing.precompute_data(computer, pretty_identifiers, paths, page=page, force=force, debug=debug)
def read_mesh_scan_point_cloud(object_set_id: str, model_id: int, *, identifier: str = None, phi: float = None, theta: float = None) -> SingleViewScan:
    """Load one precomputed single-view scan cloud.

    Provide either ``identifier`` or both ``phi`` and ``theta``.
    """
    if identifier is None:
        if phi is None or theta is None:
            raise ValueError("Provide either phi+theta or an identifier!")
        identifier = mesh_scan_identifier(phi=phi, theta=theta)
    # BUGFIX: these files used to live in "uv_scan_clouds" under the exact same
    # name as the UV scans, so the two precompute passes overwrote each other
    file = config.DATA_PATH / object_set_id / "scan_clouds" / f"{model_id}_normalized_{identifier}.h5"
    return SingleViewScan.from_h5_file(file)
def list_mesh_scan_point_cloud_h5_fnames(models: Iterable[ModelUid], identifiers: Optional[Iterable[str]] = None, **kw):
    """HDF5 file name for every (model, pose identifier) combination."""
    if identifiers is None:
        identifiers = list_mesh_scan_identifiers(**kw)
    # BUGFIX: see read_mesh_scan_point_cloud — directory changed from
    # "uv_scan_clouds" to "scan_clouds" to stop colliding with the UV files
    return [
        config.DATA_PATH / object_set_id / "scan_clouds" / f"{model_id}_normalized_{identifier}.h5"
        for object_set_id, model_id in models
        for identifier in identifiers
    ]
# === single-view UV scan clouds
def compute_mesh_scan_uv(
        object_set_id          : str,
        model_id               : int,
        phi                    : float,
        theta                  : float,
        *,
        compute_miss_distances : bool = False,
        fill_missing_points    : bool = False,
        compute_normals        : bool = True,
        convert_ok             : bool = False,
        **kw,
        ) -> SingleViewUVScan:
    """Synthetically scan one mesh from one camera pose; return the UV hit+miss cloud.

    If ``convert_ok``, a previously stored point-cloud scan is reused and
    converted instead of re-scanning (the conversion ignores the other
    keyword arguments).  Extra ``kw`` go to SingleViewUVScan.from_mesh_single_view.
    """
    if convert_ok:
        try:
            return read_mesh_scan_point_cloud(object_set_id, model_id, phi=phi, theta=theta).to_uv_scan()
        except FileNotFoundError:
            pass # nothing precomputed; fall through to a fresh scan
    mesh = read_mesh(object_set_id, model_id)
    scan = SingleViewUVScan.from_mesh_single_view(mesh,
        phi             = phi,
        theta           = theta,
        compute_normals = compute_normals,
        **kw,
    )
    if compute_miss_distances:
        scan.compute_miss_distances()
    if fill_missing_points:
        scan.fill_missing_points()
    return scan
def precompute_mesh_scan_uvs(models: Iterable[ModelUid], *, n_poses: int = 50, page: tuple[int, int] = (0, 1), force = False, debug = False, **kw):
    """Precompute all single-view UV scan clouds and store them as HDF5 datasets.

    Extra keyword arguments are forwarded to compute_mesh_scan_uv.
    """
    # BUGFIX: models is iterated several times below; a generator argument
    # would silently come up empty after the first pass
    models = list(models)
    cam_poses        = list_mesh_scan_sphere_coords(n_poses=n_poses)
    pose_identifiers = list_mesh_scan_identifiers (n_poses=n_poses)
    assert len(cam_poses) == len(pose_identifiers)
    paths = list_mesh_scan_uv_h5_fnames(models, pose_identifiers, n_poses=n_poses)
    mlen_syn = max(len(object_set_id) for object_set_id, model_id in models)
    mlen_mod = max(len(str(model_id)) for object_set_id, model_id in models)
    pretty_identifiers = [
        f"{object_set_id.ljust(mlen_syn)} @ {str(model_id).ljust(mlen_mod)} @ {i:>5} @ ({identifier}: {theta:.2f}, {phi:.2f})"
        for object_set_id, model_id in models
        for i, (identifier, (theta, phi)) in enumerate(zip(pose_identifiers, cam_poses))
    ]
    mesh_cache = [] # lets compute_mesh_scan_uv reuse the mesh across poses
    def computer(pretty_identifier: str) -> SingleViewUVScan:
        # recover the model uid and pose index from the pretty identifier
        object_set_id, model_id, index, _ = map(str.strip, pretty_identifier.split("@"))
        theta, phi = cam_poses[int(index)]
        return compute_mesh_scan_uv(object_set_id, int(model_id), phi=phi, theta=theta, _mesh_cache=mesh_cache, **kw)
    return processing.precompute_data(computer, pretty_identifiers, paths, page=page, force=force, debug=debug)
def read_mesh_scan_uv(object_set_id: str, model_id: int, *, identifier: str = None, phi: float = None, theta: float = None) -> SingleViewUVScan:
    """Load one precomputed single-view UV scan.

    Provide either ``identifier`` or both ``phi`` and ``theta``.
    """
    if identifier is None:
        if phi is None or theta is None:
            raise ValueError("Provide either phi+theta or an identifier!")
        identifier = mesh_scan_identifier(phi=phi, theta=theta)
    file = config.DATA_PATH / object_set_id / "uv_scan_clouds" / f"{model_id}_normalized_{identifier}.h5"
    return SingleViewUVScan.from_h5_file(file)
def list_mesh_scan_uv_h5_fnames(models: Iterable[ModelUid], identifiers: Optional[Iterable[str]] = None, **kw):
    """HDF5 file name of the UV scan for every (model, identifier) combination."""
    if identifiers is None:
        identifiers = list_mesh_scan_identifiers(**kw)
    fnames = []
    for object_set_id, model_id in models:
        for identifier in identifiers:
            fnames.append(config.DATA_PATH / object_set_id / "uv_scan_clouds" / f"{model_id}_normalized_{identifier}.h5")
    return fnames
# === sphere-view (UV) scan clouds
def compute_mesh_sphere_scan(
        object_set_id   : str,
        model_id        : int,
        *,
        compute_normals : bool = True,
        **kw,
        ) -> SingleViewUVScan:
    """Scan one mesh with rays between unit-sphere points; return the UV cloud."""
    mesh = read_mesh(object_set_id, model_id)
    return SingleViewUVScan.from_mesh_sphere_view(
        mesh,
        compute_normals=compute_normals,
        **kw,
    )
def precompute_mesh_sphere_scan(models: Iterable[ModelUid], *, page: tuple[int, int] = (0, 1), force: bool = False, debug: bool = False, n_points: int = 4000, **kw):
    """Precompute the sphere-view scan cloud of every model as an HDF5 dataset."""
    # BUGFIX: models is iterated twice (paths + identifiers); a generator
    # argument would make the identifier list come up empty
    models = list(models)
    # NOTE(review): n_points is accepted but never forwarded to
    # compute_mesh_sphere_scan — confirm whether it should be passed along.
    paths = list_mesh_sphere_scan_h5_fnames(models)
    identifiers = [model_uid_to_string(*i) for i in models]
    def computer(identifier: str) -> SingleViewUVScan:
        # map the "set-id" string back to a model uid and scan it
        object_set_id, model_id = model_id_string_to_uid(identifier)
        return compute_mesh_sphere_scan(object_set_id, model_id, **kw)
    return processing.precompute_data(computer, identifiers, paths, page=page, force=force, debug=debug)
def read_mesh_mesh_sphere_scan(object_set_id: str, model_id: int) -> SingleViewUVScan:
    """Load the precomputed sphere-view scan of one model."""
    # NOTE(review): the doubled "mesh_mesh" in the name looks like a typo, but
    # the name is kept for interface compatibility
    fname = config.DATA_PATH / object_set_id / "sphere_scan_clouds" / f"{model_id}_normalized.h5"
    return SingleViewUVScan.from_h5_file(fname)
def list_mesh_sphere_scan_h5_fnames(models: Iterable[ModelUid]) -> list[Path]:
    """HDF5 file name of the sphere-view scan for each model.

    BUGFIX: the annotation said ``list[str]``, but the values are pathlib.Path
    objects (matching the stanford counterpart of this function).
    """
    return [
        config.DATA_PATH / object_set_id / "sphere_scan_clouds" / f"{model_id}_normalized.h5"
        for object_set_id, model_id in models
    ]

View File

@@ -0,0 +1,76 @@
from ..config import data_path_get, data_path_persist
from collections import namedtuple
import os
# Data source:
# http://graphics.stanford.edu/data/3Dscanrep/
__ALL__ = ["config", "Model", "MODELS"]
@(lambda x: x()) # singleton: the name `config` is rebound to the only instance
class config:
    # Defined as a class attribute so that attribute access on the singleton
    # instance goes through the descriptor protocol: reading config.DATA_PATH
    # resolves the configured path, assigning it persists the new path.
    DATA_PATH = property(
        doc = """
        Path to the dataset. The following envvars override it:
        ${IFIELD_DATA_MODELS}/stanford
        ${IFIELD_DATA_MODELS_STANFORD}
        """,
        fget = lambda self: data_path_get ("stanford"),
        fset = lambda self, path: data_path_persist("stanford", path),
    )
    @property
    def IS_DOWNLOADED_DB(self) -> list[os.PathLike]:
        # JSON registry file(s) recording which urls have been downloaded
        # (consumed by common.download.is_downloaded via the dbfiles argument)
        return [
            self.DATA_PATH / "downloaded.json",
        ]
# One downloadable model: its url, the path of the main reconstructed mesh
# inside the downloaded archive, and a human-readable download size
# (shown by `download-stanford --list-sizes`).
Model = namedtuple("Model", "url mesh_fname download_size_str")
# All models of The Stanford 3D Scanning Repository that this package knows
# how to download, keyed by the object name used on the command line.
MODELS: dict[str, Model] = {
    "bunny": Model(
        "http://graphics.stanford.edu/pub/3Dscanrep/bunny.tar.gz",
        "bunny/reconstruction/bun_zipper.ply",
        "4.89M",
    ),
    "drill_bit": Model(
        "http://graphics.stanford.edu/pub/3Dscanrep/drill.tar.gz",
        "drill/reconstruction/drill_shaft_vrip.ply",
        "555k",
    ),
    "happy_buddha": Model(
        # religious symbol
        "http://graphics.stanford.edu/pub/3Dscanrep/happy/happy_recon.tar.gz",
        "happy_recon/happy_vrip.ply",
        "14.5M",
    ),
    "dragon": Model(
        # symbol of Chinese culture
        "http://graphics.stanford.edu/pub/3Dscanrep/dragon/dragon_recon.tar.gz",
        "dragon_recon/dragon_vrip.ply",
        "11.2M",
    ),
    "armadillo": Model(
        "http://graphics.stanford.edu/pub/3Dscanrep/armadillo/Armadillo.ply.gz",
        "armadillo.ply.gz",
        "3.87M",
    ),
    "lucy": Model(
        # Christian angel
        "http://graphics.stanford.edu/data/3Dscanrep/lucy.tar.gz",
        "lucy.ply",
        "322M",
    ),
    "asian_dragon": Model(
        # symbol of Chinese culture
        "http://graphics.stanford.edu/data/3Dscanrep/xyzrgb/xyzrgb_dragon.ply.gz",
        "xyzrgb_dragon.ply.gz",
        "70.5M",
    ),
    "thai_statue": Model(
        # Hindu religious significance
        "http://graphics.stanford.edu/data/3Dscanrep/xyzrgb/xyzrgb_statuette.ply.gz",
        "xyzrgb_statuette.ply.gz",
        "106M",
    ),
}

View File

@@ -0,0 +1,129 @@
#!/usr/bin/env python3
from . import config
from ...utils.helpers import make_relative
from ..common import download
from pathlib import Path
from textwrap import dedent
from typing import Iterable
import argparse
import io
import tarfile
def is_downloaded(*args, **kwargs):
    """Wrapper around common.download.is_downloaded pinned to this dataset's registry."""
    return download.is_downloaded(*args, dbfiles=config.IS_DOWNLOADED_DB, **kwargs)
def download_and_extract(target_dir: Path, url_list: Iterable[str], *, force=False, silent=False) -> bool:
    """Download each url into ``target_dir`` and unpack it.

    Returns True if at least one url was (re)downloaded.  Already-downloaded
    urls are skipped unless ``force``; successes are recorded in the registry.
    """
    target_dir.mkdir(parents=True, exist_ok=True)
    ret = False
    for url in url_list:
        if not force:
            if is_downloaded(target_dir, url): continue
        if not download.check_url(url): # HTTP HEAD probe
            print("ERROR:", url)
            continue
        ret = True
        data = download.download_data(url, silent=silent, label=str(Path(url).name))
        print("extracting...")
        if url.endswith(".ply.gz"):
            # store the gzipped mesh as-is; the read side decompresses it
            fname = target_dir / "meshes" / url.split("/")[-1].lower()
            fname.parent.mkdir(parents=True, exist_ok=True)
            with fname.open("wb") as f:
                f.write(data)
        elif url.endswith(".tar.gz"):
            with tarfile.open(fileobj=io.BytesIO(data)) as tar:
                # extract only plain, relative, non-hidden member files
                for member in tar.getmembers():
                    if not member.isfile(): continue
                    if member.name.startswith("/"): continue
                    if member.name.startswith("."): continue
                    # BUGFIX: also reject ".." path components, otherwise a
                    # malicious archive could write outside target_dir
                    if ".." in Path(member.name).parts: continue
                    if Path(member.name).name.startswith("."): continue
                    tar.extract(member, target_dir / "meshes")
            del tar
        else:
            raise NotImplementedError(f"Extraction for {str(Path(url).name)} unknown")
        is_downloaded(target_dir, url, add=True) # record success in the registry
        del data # free the in-memory archive before the next download
    return ret
def make_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for the `download-stanford` entry point."""
    parser = argparse.ArgumentParser(
        description=dedent("""
            Download The Stanford 3D Scanning Repository models.
            More info: http://graphics.stanford.edu/data/3Dscanrep/
            Example:
                download-stanford bunny
            """),
        formatter_class=argparse.RawTextHelpFormatter,
    )
    add = parser.add_argument # brevity
    add("objects", nargs="*", default=[],
        help="Which objects to download, defaults to none.")
    add("--all", action="store_true",
        help="Download all objects")
    default_display = str(make_relative(config.DATA_PATH, Path.cwd()))
    add("--dir", default=str(config.DATA_PATH),
        help=f"The target directory. Default is {default_display!r}")
    add("--list", action="store_true",
        help="Lists all the objects")
    add("--list-urls", action="store_true",
        help="Lists the urls to download")
    add("--list-sizes", action="store_true",
        help="Lists the download size of each model")
    add("--silent", action="store_true",
        help="")
    add("--force", action="store_true",
        help="Download again even if already downloaded")
    return parser
# entrypoint
def cli(parser=make_parser()):
    """Entry point for `download-stanford`.

    NOTE: the default parser is built once at import time; this is deliberate
    so console_scripts can reference `cli` without arguments.
    """
    args = parser.parse_args()
    obj_names = sorted(set(args.objects))
    if args.all:
        assert not obj_names
        obj_names = sorted(config.MODELS.keys())
    # BUGFIX: this branch used to evaluate config.MODELS.keys() without
    # assigning it, so `--list-urls` with no objects given printed nothing
    # instead of listing every url
    if not obj_names and args.list_urls:
        obj_names = sorted(config.MODELS.keys())
    if args.list:
        print(*config.MODELS.keys(), sep="\n")
        exit()
    if args.list_sizes:
        print(*(f"{obj_name:<15}{config.MODELS[obj_name].download_size_str}" for obj_name in (obj_names or config.MODELS.keys())), sep="\n")
        exit()
    try:
        url_list = [config.MODELS[obj_name].url for obj_name in obj_names]
    except KeyError:
        print("Error: unrecognized object name:", *set(obj_names).difference(config.MODELS.keys()), sep="\n")
        exit(1)
    if not url_list:
        print("Error: No object set was selected for download!")
        exit(1)
    if args.list_urls:
        print(*url_list, sep="\n")
        exit()
    print("Download start")
    any_downloaded = download_and_extract(
        target_dir = Path(args.dir),
        url_list   = url_list,
        force      = args.force,
        silent     = args.silent,
    )
    if not any_downloaded:
        print("Everything has already been downloaded, skipping.")
if __name__ == "__main__":
    cli()

View File

@@ -0,0 +1,118 @@
#!/usr/bin/env python3
import os; os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
from . import config, read
from ...utils.helpers import make_relative
from pathlib import Path
from textwrap import dedent
import argparse
def make_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for preprocessing downloaded Stanford models."""
    parser = argparse.ArgumentParser(description=dedent("""
        Preprocess the Stanford models. Depends on `download-stanford` having been run.
        """), formatter_class=argparse.RawTextHelpFormatter)
    arg = parser.add_argument # brevity
    arg("objects", nargs="*", default=[],
        help="Which objects to process, defaults to all downloaded")
    arg("--dir", default=str(config.DATA_PATH),
        help=f"The target directory. Default is {make_relative(config.DATA_PATH, Path.cwd()).__str__()!r}")
    arg("--force", action="store_true",
        help="Overwrite existing files")
    arg("--list", action="store_true",
        help="List the downloaded models available for preprocessing")
    arg("--list-pages", type=int, default=None,
        help="List the downloaded models available for preprocessing, paginated into N pages.")
    arg("--page", nargs=2, type=int, default=[0, 1],
        help="Subset of parts to compute. Use to parallelize. (page, total), page is 0 indexed")
    # which artifacts to produce (cli() requires at least one of these)
    arg2 = parser.add_argument_group("preprocessing targets").add_argument # brevity
    arg2("--precompute-mesh-sv-scan-clouds", action="store_true",
        help="Compute single-view hit+miss point clouds from 100 synthetic scans.")
    arg2("--precompute-mesh-sv-scan-uvs", action="store_true",
        help="Compute single-view hit+miss UV clouds from 100 synthetic scans.")
    arg2("--precompute-mesh-sphere-scan", action="store_true",
        help="Compute a sphere-view hit+miss cloud cast from n to n unit sphere points.")
    # options forwarded by cli() into the read.precompute_* calls
    arg3 = parser.add_argument_group("ray-scan modifiers").add_argument # brevity
    arg3("--n-sphere-points", type=int, default=4000,
        help="The number of unit-sphere points to sample rays from. Final result: n*(n-1).")
    arg3("--compute-miss-distances", action="store_true",
        help="Compute the distance to the nearest hit for each miss in the hit+miss clouds.")
    arg3("--fill-missing-uv-points", action="store_true",
        help="TODO")
    arg3("--no-filter-backhits", action="store_true",
        help="Do not filter scan hits on backside of mesh faces.")
    arg3("--no-unit-sphere", action="store_true",
        help="Do not center the objects to the unit sphere.")
    arg3("--convert-ok", action="store_true",
        help="Allow reusing point clouds for uv clouds and vice versa. (does not account for other hparams)")
    arg3("--debug", action="store_true",
        help="Abort on failiure.")
        # NOTE(review): "failiure" typo is user-visible help text; left as-is
        # here since a doc pass must not change runtime strings
    arg5 = parser.add_argument_group("Shared modifiers").add_argument # brevity
    arg5("--scan-resolution", type=int, default=400,
        help="The resolution of the depth map rendered to sample points. Becomes x*x")
    return parser
# entrypoint
def cli(parser: argparse.ArgumentParser = make_parser()):
    """Entry point for preprocessing Stanford models.

    NOTE: the default parser is constructed once, at import time; this is
    intentional so console_scripts can reference `cli` directly.
    NOTE(review): --scan-resolution is parsed but never forwarded to any
    read.precompute_* call below — confirm whether that is intended.
    """
    args = parser.parse_args()
    # require at least one --precompute-* flag unless the user is just listing
    if not any(getattr(args, k) for k in dir(args) if k.startswith("precompute_")) and not (args.list or args.list_pages):
        parser.error("no preprocessing target selected") # exits
    config.DATA_PATH = Path(args.dir) # property setter: persists the chosen path
    obj_names = args.objects or read.list_object_names()
    if args.list:
        print(*obj_names, sep="\n")
        parser.exit()
    if args.list_pages is not None:
        # emit one ready-made argument suffix per (page, object) combination,
        # for easy parallelization from a shell script
        print(*(
            f"--page {i} {args.list_pages} {obj_name}"
            for obj_name in obj_names
            for i in range(args.list_pages)
        ), sep="\n")
        parser.exit()
    if args.precompute_mesh_sv_scan_clouds:
        read.precompute_mesh_scan_point_clouds(
            obj_names,
            compute_miss_distances = args.compute_miss_distances,
            no_filter_backhits     = args.no_filter_backhits,
            no_unit_sphere         = args.no_unit_sphere,
            convert_ok             = args.convert_ok,
            page                   = args.page,
            force                  = args.force,
            debug                  = args.debug,
        )
    if args.precompute_mesh_sv_scan_uvs:
        read.precompute_mesh_scan_uvs(
            obj_names,
            compute_miss_distances = args.compute_miss_distances,
            fill_missing_points    = args.fill_missing_uv_points,
            no_filter_backhits     = args.no_filter_backhits,
            no_unit_sphere         = args.no_unit_sphere,
            convert_ok             = args.convert_ok,
            page                   = args.page,
            force                  = args.force,
            debug                  = args.debug,
        )
    if args.precompute_mesh_sphere_scan:
        read.precompute_mesh_sphere_scan(
            obj_names,
            # NOTE(review): forwarded via **kw down to the scan constructor;
            # confirm `sphere_points` is the keyword it expects
            sphere_points          = args.n_sphere_points,
            compute_miss_distances = args.compute_miss_distances,
            no_filter_backhits     = args.no_filter_backhits,
            no_unit_sphere         = args.no_unit_sphere,
            page                   = args.page,
            force                  = args.force,
            debug                  = args.debug,
        )
if __name__ == "__main__":
    cli()

View File

@@ -0,0 +1,251 @@
from . import config
from ..common import points
from ..common import processing
from ..common.scan import SingleViewScan, SingleViewUVScan
from ..common.types import MalformedMesh
from functools import lru_cache, wraps
from typing import Optional, Iterable
from pathlib import Path
import gzip
import numpy as np
import trimesh
import trimesh.transformations as T
# BUGFIX: the docstring said "shapenet", but this module reads the stanford
# data (the same paragraph already said "none found in stanford")
__doc__ = """
Here are functions for reading and preprocessing stanford benchmark data
There are essentially a few sets per object:
"img" - meaning the RGBD images (none found in stanford)
"mesh_scans" - meaning synthetic scans of a mesh
"""
# Base rotation (+90deg about x) so meshes stand upright in pyrender.
MESH_TRANSFORM_SKYWARD = T.rotation_matrix(np.pi/2, (1, 0, 0))
# Full load-time transform per model, giving all models a shared canonical
# orientation; models not listed fall back to MESH_TRANSFORM_SKYWARD alone
# (see read_mesh).
MESH_TRANSFORM_CANONICAL = { # to gain a shared canonical orientation
    "armadillo"     : T.rotation_matrix(np.pi,    (0, 0, 1)) @ MESH_TRANSFORM_SKYWARD,
    "asian_dragon"  : T.rotation_matrix(-np.pi/2, (0, 0, 1)) @ MESH_TRANSFORM_SKYWARD,
    "bunny"         : MESH_TRANSFORM_SKYWARD,
    "dragon"        : MESH_TRANSFORM_SKYWARD,
    "drill_bit"     : MESH_TRANSFORM_SKYWARD,
    "happy_buddha"  : MESH_TRANSFORM_SKYWARD,
    # NOTE(review): "lucy" gets only the z-rotation, without the skyward
    # tilt applied to every other model — confirm this is intentional
    "lucy"          : T.rotation_matrix(np.pi, (0, 0, 1)),
    "thai_statue"   : MESH_TRANSFORM_SKYWARD,
}
def list_object_names() -> list[str]:
    """Names of the models whose mesh file has already been downloaded."""
    downloaded = []
    for name, model in config.MODELS.items():
        if (config.DATA_PATH / "meshes" / model.mesh_fname).is_file():
            downloaded.append(name)
    return downloaded
@lru_cache(maxsize=1)
def list_mesh_scan_sphere_coords(n_poses: int = 50) -> list[tuple[float, float]]: # (theta, phi)
    """(theta, phi) sphere coordinates of ``n_poses`` equidistant camera poses."""
    # (a shift_theta=True variant was considered here at some point)
    coords = points.generate_equidistant_sphere_points(n_poses, compute_sphere_coordinates=True)
    return coords
def mesh_scan_identifier(*, phi: float, theta: float) -> str:
    """Encode a camera pose as a filesystem-safe token, e.g. ``n0d50p1d00``."""
    parts = []
    for value in (theta, phi): # theta first, then phi
        parts.append("p" if value >= 0 else "n")
        parts.append(f"{abs(value):.2f}".replace(".", "d"))
    return "".join(parts)
@lru_cache(maxsize=1)
def list_mesh_scan_identifiers(n_poses: int = 50) -> list[str]:
    """Identifier tokens, one per camera pose, ordered like the pose list."""
    identifiers = []
    for theta, phi in list_mesh_scan_sphere_coords(n_poses):
        identifiers.append(mesh_scan_identifier(phi=phi, theta=theta))
    assert len(identifiers) == len(set(identifiers))
    return identifiers
# ===
@lru_cache(maxsize=1)
def read_mesh(obj_name: str) -> trimesh.Trimesh:
    """Load the mesh for ``obj_name``, reoriented to its canonical pose.

    The most recently loaded mesh is cached (maxsize=1); the returned object
    IS the cached instance, so callers must not mutate it.

    Raises:
        FileNotFoundError: if the mesh has not been downloaded.
        MalformedMesh: if trimesh fails to parse the file.
    """
    path = config.DATA_PATH / "meshes" / config.MODELS[obj_name].mesh_fname
    if not path.exists():
        raise FileNotFoundError(f"{obj_name = } -> {str(path) = }")
    try:
        if path.suffixes[-1] == ".gz":
            # decompress on the fly; tell trimesh the inner file type
            # (e.g. "foo.ply.gz" -> "ply")
            with gzip.open(path, "r") as f:
                mesh = trimesh.load(f, file_type="".join(path.suffixes[:-1])[1:])
        else:
            mesh = trimesh.load(path)
    except Exception as e:
        raise MalformedMesh(f"Trimesh raised: {e.__class__.__name__}: {e}") from e
    # rotate to be upright in pyrender
    mesh.apply_transform(MESH_TRANSFORM_CANONICAL.get(obj_name, MESH_TRANSFORM_SKYWARD))
    return mesh
# === single-view scan clouds
def compute_mesh_scan_point_cloud(
        obj_name               : str,
        *,
        phi                    : float,
        theta                  : float,
        compute_miss_distances : bool = False,
        compute_normals        : bool = True,
        convert_ok             : bool = False, # this does not respect the other hparams
        **kw,
        ) -> SingleViewScan:
    """Synthetically scan ``obj_name`` from one camera pose; return the hit+miss cloud.

    If ``convert_ok``, a previously stored UV scan is reused and converted
    instead of re-scanning (ignoring the other hyperparameters).  Extra ``kw``
    are forwarded to SingleViewScan.from_mesh_single_view.
    """
    if convert_ok:
        try:
            return read_mesh_scan_uv(obj_name, phi=phi, theta=theta).to_scan()
        except FileNotFoundError:
            pass # nothing precomputed; fall through to a fresh scan
    mesh = read_mesh(obj_name)
    return SingleViewScan.from_mesh_single_view(mesh,
        phi                    = phi,
        theta                  = theta,
        compute_normals        = compute_normals,
        compute_miss_distances = compute_miss_distances,
        **kw,
    )
def precompute_mesh_scan_point_clouds(obj_names, *, page: tuple[int, int] = (0, 1), force: bool = False, debug: bool = False, n_poses: int = 50, **kw):
    """Precompute all single-view scan clouds and store them as HDF5 datasets.

    Extra keyword arguments are forwarded to compute_mesh_scan_point_cloud.
    """
    # BUGFIX: obj_names is iterated twice (paths + pretty identifiers); a
    # generator argument would silently yield an empty identifier list
    obj_names = list(obj_names)
    cam_poses        = list_mesh_scan_sphere_coords(n_poses)
    pose_identifiers = list_mesh_scan_identifiers (n_poses)
    assert len(cam_poses) == len(pose_identifiers)
    paths = list_mesh_scan_point_cloud_h5_fnames(obj_names, pose_identifiers)
    mlen = max(map(len, config.MODELS.keys()))
    pretty_identifiers = [
        f"{obj_name.ljust(mlen)} @ {i:>5} @ ({identifier}: {theta:.2f}, {phi:.2f})"
        for obj_name in obj_names
        for i, (identifier, (theta, phi)) in enumerate(zip(pose_identifiers, cam_poses))
    ]
    mesh_cache = [] # lets compute_mesh_scan_point_cloud reuse the mesh across poses
    @wraps(compute_mesh_scan_point_cloud)
    def computer(pretty_identifier: str) -> SingleViewScan:
        # recover the object name and pose index from the pretty identifier
        obj_name, index, _ = map(str.strip, pretty_identifier.split("@"))
        theta, phi = cam_poses[int(index)]
        return compute_mesh_scan_point_cloud(obj_name, phi=phi, theta=theta, _mesh_cache=mesh_cache, **kw)
    return processing.precompute_data(computer, pretty_identifiers, paths, page=page, force=force, debug=debug)
def read_mesh_scan_point_cloud(obj_name, *, identifier: str = None, phi: float = None, theta: float = None) -> SingleViewScan:
    """Load one precomputed single-view scan cloud.

    Provide either ``identifier`` or both ``phi`` and ``theta``.
    Raises FileNotFoundError if the cloud has not been precomputed.
    """
    if identifier is None:
        if phi is None or theta is None:
            raise ValueError("Provide either phi+theta or an identifier!")
        identifier = mesh_scan_identifier(phi=phi, theta=theta)
    file = config.DATA_PATH / "clouds" / obj_name / f"mesh_scan_{identifier}_clouds.h5"
    if not file.exists(): raise FileNotFoundError(str(file))
    return SingleViewScan.from_h5_file(file)
def list_mesh_scan_point_cloud_h5_fnames(obj_names: Iterable[str], identifiers: Optional[Iterable[str]] = None, **kw) -> list[Path]:
    """HDF5 file name for every (object, pose identifier) combination."""
    if identifiers is None:
        identifiers = list_mesh_scan_identifiers(**kw)
    fnames = []
    for obj_name in obj_names:
        for identifier in identifiers:
            fnames.append(config.DATA_PATH / "clouds" / obj_name / f"mesh_scan_{identifier}_clouds.h5")
    return fnames
# === single-view UV scan clouds
def compute_mesh_scan_uv(
        obj_name               : str,
        *,
        phi                    : float,
        theta                  : float,
        compute_miss_distances : bool = False,
        fill_missing_points    : bool = False,
        compute_normals        : bool = True,
        convert_ok             : bool = False,
        **kw,
        ) -> SingleViewUVScan:
    """Synthetically scan ``obj_name`` from one camera pose; return the UV hit+miss cloud.

    If ``convert_ok``, a previously stored point-cloud scan is reused and
    converted instead of re-scanning (the conversion ignores the other
    hyperparameters).  Extra ``kw`` go to SingleViewUVScan.from_mesh_single_view.
    """
    if convert_ok:
        try:
            return read_mesh_scan_point_cloud(obj_name, phi=phi, theta=theta).to_uv_scan()
        except FileNotFoundError:
            pass # nothing precomputed; fall through to a fresh scan
    mesh = read_mesh(obj_name)
    scan = SingleViewUVScan.from_mesh_single_view(mesh,
        phi             = phi,
        theta           = theta,
        compute_normals = compute_normals,
        **kw,
    )
    if compute_miss_distances:
        scan.compute_miss_distances()
    if fill_missing_points:
        scan.fill_missing_points()
    return scan
def precompute_mesh_scan_uvs(obj_names, *, page: tuple[int, int] = (0, 1), force: bool = False, debug: bool = False, n_poses: int = 50, **kw):
    """Precompute all single-view UV scan clouds and store them as HDF5 datasets.

    Extra keyword arguments are forwarded to compute_mesh_scan_uv.
    """
    # BUGFIX: obj_names is iterated twice (paths + pretty identifiers); a
    # generator argument would silently yield an empty identifier list
    obj_names = list(obj_names)
    cam_poses        = list_mesh_scan_sphere_coords(n_poses)
    pose_identifiers = list_mesh_scan_identifiers (n_poses)
    assert len(cam_poses) == len(pose_identifiers)
    paths = list_mesh_scan_uv_h5_fnames(obj_names, pose_identifiers)
    mlen = max(map(len, config.MODELS.keys()))
    pretty_identifiers = [
        f"{obj_name.ljust(mlen)} @ {i:>5} @ ({identifier}: {theta:.2f}, {phi:.2f})"
        for obj_name in obj_names
        for i, (identifier, (theta, phi)) in enumerate(zip(pose_identifiers, cam_poses))
    ]
    mesh_cache = [] # lets compute_mesh_scan_uv reuse the mesh across poses
    @wraps(compute_mesh_scan_uv)
    def computer(pretty_identifier: str) -> SingleViewUVScan:
        # BUGFIX: the return annotation used to say SingleViewScan, but
        # compute_mesh_scan_uv returns a SingleViewUVScan
        obj_name, index, _ = map(str.strip, pretty_identifier.split("@"))
        theta, phi = cam_poses[int(index)]
        return compute_mesh_scan_uv(obj_name, phi=phi, theta=theta, _mesh_cache=mesh_cache, **kw)
    return processing.precompute_data(computer, pretty_identifiers, paths, page=page, force=force, debug=debug)
def read_mesh_scan_uv(obj_name, *, identifier: str = None, phi: float = None, theta: float = None) -> SingleViewUVScan:
    """Load one precomputed single-view UV scan.

    Provide either ``identifier`` or both ``phi`` and ``theta``.
    Raises FileNotFoundError if the scan has not been precomputed.
    """
    if identifier is None:
        if phi is None or theta is None:
            raise ValueError("Provide either phi+theta or an identifier!")
        identifier = mesh_scan_identifier(phi=phi, theta=theta)
    file = config.DATA_PATH / "clouds" / obj_name / f"mesh_scan_{identifier}_uv.h5"
    if not file.exists(): raise FileNotFoundError(str(file))
    return SingleViewUVScan.from_h5_file(file)
def list_mesh_scan_uv_h5_fnames(obj_names: Iterable[str], identifiers: Optional[Iterable[str]] = None, **kw) -> list[Path]:
    """HDF5 file name of the UV scan for every (object, identifier) combination."""
    if identifiers is None:
        identifiers = list_mesh_scan_identifiers(**kw)
    fnames = []
    for obj_name in obj_names:
        for identifier in identifiers:
            fnames.append(config.DATA_PATH / "clouds" / obj_name / f"mesh_scan_{identifier}_uv.h5")
    return fnames
# === sphere-view (UV) scan clouds
def compute_mesh_sphere_scan(
        obj_name        : str,
        *,
        compute_normals : bool = True,
        **kw,
        ) -> SingleViewUVScan:
    """Sphere-view scan of one stanford model, returned as a UV hit+miss cloud."""
    return SingleViewUVScan.from_mesh_sphere_view(
        read_mesh(obj_name),
        compute_normals=compute_normals,
        **kw,
    )
def precompute_mesh_sphere_scan(obj_names, *, page: tuple[int, int] = (0, 1), force: bool = False, debug: bool = False, n_points: int = 4000, **kw):
    """Precompute the sphere-view scan cloud of each object as an HDF5 dataset."""
    # BUGFIX: obj_names is iterated twice (paths + precompute_data identifiers);
    # a generator argument would break silently
    obj_names = list(obj_names)
    # NOTE(review): n_points is accepted but never forwarded to
    # compute_mesh_sphere_scan — confirm whether it should be passed along.
    paths = list_mesh_sphere_scan_h5_fnames(obj_names)
    @wraps(compute_mesh_sphere_scan)
    def computer(obj_name: str) -> SingleViewUVScan:
        # BUGFIX: the return annotation used to say SingleViewScan, but
        # compute_mesh_sphere_scan returns a SingleViewUVScan
        return compute_mesh_sphere_scan(obj_name, **kw)
    return processing.precompute_data(computer, obj_names, paths, page=page, force=force, debug=debug)
def read_mesh_mesh_sphere_scan(obj_name) -> SingleViewUVScan:
    """Load the precomputed sphere-view scan for ``obj_name``."""
    # NOTE(review): the doubled "mesh_mesh" in the name looks like a typo, but
    # the name is kept for interface compatibility
    path = config.DATA_PATH / "clouds" / obj_name / "mesh_sphere_scan.h5"
    if not path.exists():
        raise FileNotFoundError(str(path))
    return SingleViewUVScan.from_h5_file(path)
def list_mesh_sphere_scan_h5_fnames(obj_names: Iterable[str]) -> list[Path]:
    """HDF5 file name of the sphere-view scan for each object name."""
    fnames = []
    for obj_name in obj_names:
        fnames.append(config.DATA_PATH / "clouds" / obj_name / "mesh_sphere_scan.h5")
    return fnames