Add code
This commit is contained in:
3
ifield/data/__init__.py
Normal file
3
ifield/data/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
__doc__ = """
|
||||
Submodules to read and process datasets
|
||||
"""
|
||||
0
ifield/data/common/__init__.py
Normal file
0
ifield/data/common/__init__.py
Normal file
90
ifield/data/common/download.py
Normal file
90
ifield/data/common/download.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from ...utils.helpers import make_relative
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from typing import Union, Optional
|
||||
import io
|
||||
import os
|
||||
import json
|
||||
import requests
|
||||
|
||||
PathLike = Union[os.PathLike, str]
|
||||
|
||||
__doc__ = """
|
||||
Here are some helper functions for processing data.
|
||||
"""
|
||||
|
||||
def check_url(url): # HTTP HEAD
|
||||
return requests.head(url).ok
|
||||
|
||||
def download_stream(
|
||||
url : str,
|
||||
file_object,
|
||||
block_size : int = 1024,
|
||||
silent : bool = False,
|
||||
label : Optional[str] = None,
|
||||
):
|
||||
resp = requests.get(url, stream=True)
|
||||
total_size = int(resp.headers.get("content-length", 0))
|
||||
if not silent:
|
||||
progress_bar = tqdm(total=total_size , unit="iB", unit_scale=True, desc=label)
|
||||
|
||||
for chunk in resp.iter_content(block_size):
|
||||
if not silent:
|
||||
progress_bar.update(len(chunk))
|
||||
file_object.write(chunk)
|
||||
|
||||
if not silent:
|
||||
progress_bar.close()
|
||||
if total_size != 0 and progress_bar.n != total_size:
|
||||
print("ERROR, something went wrong")
|
||||
|
||||
def download_data(
|
||||
url : str,
|
||||
block_size : int = 1024,
|
||||
silent : bool = False,
|
||||
label : Optional[str] = None,
|
||||
) -> bytearray:
|
||||
f = io.BytesIO()
|
||||
download_stream(url, f, block_size=block_size, silent=silent, label=label)
|
||||
f.seek(0)
|
||||
return bytearray(f.read())
|
||||
|
||||
def download_file(
|
||||
url : str,
|
||||
fname : Union[Path, str],
|
||||
block_size : int = 1024,
|
||||
silent = False,
|
||||
):
|
||||
if not isinstance(fname, Path):
|
||||
fname = Path(fname)
|
||||
with fname.open("wb") as f:
|
||||
download_stream(url, f, block_size=block_size, silent=silent, label=make_relative(fname, Path.cwd()).name)
|
||||
|
||||
def is_downloaded(
|
||||
target_dir : PathLike,
|
||||
url : str,
|
||||
*,
|
||||
add : bool = False,
|
||||
dbfiles : Union[list[PathLike], PathLike],
|
||||
):
|
||||
if not isinstance(target_dir, os.PathLike):
|
||||
target_dir = Path(target_dir)
|
||||
if not isinstance(dbfiles, list):
|
||||
dbfiles = [dbfiles]
|
||||
if not dbfiles:
|
||||
raise ValueError("'dbfiles' empty")
|
||||
downloaded = set()
|
||||
for dbfile_fname in dbfiles:
|
||||
dbfile_fname = target_dir / dbfile_fname
|
||||
if dbfile_fname.is_file():
|
||||
with open(dbfile_fname, "r") as f:
|
||||
downloaded.update(json.load(f)["downloaded"])
|
||||
|
||||
if add and url not in downloaded:
|
||||
downloaded.add(url)
|
||||
with open(dbfiles[0], "w") as f:
|
||||
data = {"downloaded": sorted(downloaded)}
|
||||
json.dump(data, f, indent=2, sort_keys=True)
|
||||
return True
|
||||
|
||||
return url in downloaded
|
||||
370
ifield/data/common/h5_dataclasses.py
Normal file
370
ifield/data/common/h5_dataclasses.py
Normal file
@@ -0,0 +1,370 @@
|
||||
#!/usr/bin/env python3
|
||||
from abc import abstractmethod, ABCMeta
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
import copy
|
||||
import dataclasses
|
||||
import functools
|
||||
import h5py as h5
|
||||
import hdf5plugin
|
||||
import numpy as np
|
||||
import operator
|
||||
import os
|
||||
import sys
|
||||
import typing
|
||||
|
||||
__all__ = [
|
||||
"DataclassMeta",
|
||||
"Dataclass",
|
||||
"H5Dataclass",
|
||||
"H5Array",
|
||||
"H5ArrayNoSlice",
|
||||
]
|
||||
|
||||
T = typing.TypeVar("T")
|
||||
NoneType = type(None)
|
||||
PathLike = typing.Union[os.PathLike, str]
|
||||
H5Array = typing._alias(np.ndarray, 0, inst=False, name="H5Array")
|
||||
H5ArrayNoSlice = typing._alias(np.ndarray, 0, inst=False, name="H5ArrayNoSlice")
|
||||
|
||||
DataclassField = namedtuple("DataclassField", [
|
||||
"name",
|
||||
"type",
|
||||
"is_optional",
|
||||
"is_array",
|
||||
"is_sliceable",
|
||||
"is_prefix",
|
||||
])
|
||||
|
||||
def strip_optional(val: type) -> type:
|
||||
if typing.get_origin(val) is typing.Union:
|
||||
union = set(typing.get_args(val))
|
||||
if len(union - {NoneType}) == 1:
|
||||
val, = union - {NoneType}
|
||||
else:
|
||||
raise TypeError(f"Non-'typing.Optional' 'typing.Union' is not supported: {typing._type_repr(val)!r}")
|
||||
return val
|
||||
|
||||
def is_array(val, *, _inner=False):
|
||||
"""
|
||||
Hacky way to check if a value or type is an array.
|
||||
The hack omits having to depend on large frameworks such as pytorch or pandas
|
||||
"""
|
||||
val = strip_optional(val)
|
||||
if val is H5Array or val is H5ArrayNoSlice:
|
||||
return True
|
||||
|
||||
if typing._type_repr(val) in (
|
||||
"numpy.ndarray",
|
||||
"torch.Tensor",
|
||||
):
|
||||
return True
|
||||
if not _inner:
|
||||
return is_array(type(val), _inner=True)
|
||||
return False
|
||||
|
||||
def prod(numbers: typing.Iterable[T], initial: typing.Optional[T] = None) -> T:
|
||||
if initial is not None:
|
||||
return functools.reduce(operator.mul, numbers, initial)
|
||||
else:
|
||||
return functools.reduce(operator.mul, numbers)
|
||||
|
||||
class DataclassMeta(type):
|
||||
def __new__(
|
||||
mcls,
|
||||
name : str,
|
||||
bases : tuple[type, ...],
|
||||
attrs : dict[str, typing.Any],
|
||||
**kwargs,
|
||||
):
|
||||
cls = super().__new__(mcls, name, bases, attrs, **kwargs)
|
||||
if sys.version_info[:2] >= (3, 10) and not hasattr(cls, "__slots__"):
|
||||
cls = dataclasses.dataclass(slots=True)(cls)
|
||||
else:
|
||||
cls = dataclasses.dataclass(cls)
|
||||
return cls
|
||||
|
||||
class DataclassABCMeta(DataclassMeta, ABCMeta):
|
||||
pass
|
||||
|
||||
class Dataclass(metaclass=DataclassMeta):
|
||||
def __getitem__(self, key: str) -> typing.Any:
|
||||
if key in self.keys():
|
||||
return getattr(self, key)
|
||||
raise KeyError(key)
|
||||
|
||||
def __setitem__(self, key: str, value: typing.Any):
|
||||
if key in self.keys():
|
||||
return setattr(self, key, value)
|
||||
raise KeyError(key)
|
||||
|
||||
def keys(self) -> typing.KeysView:
|
||||
return self.as_dict().keys()
|
||||
|
||||
def values(self) -> typing.ValuesView:
|
||||
return self.as_dict().values()
|
||||
|
||||
def items(self) -> typing.ItemsView:
|
||||
return self.as_dict().items()
|
||||
|
||||
def as_dict(self, properties_to_include: set[str] = None, **kw) -> dict[str, typing.Any]:
|
||||
out = dataclasses.asdict(self, **kw)
|
||||
for name in (properties_to_include or []):
|
||||
out[name] = getattr(self, name)
|
||||
return out
|
||||
|
||||
def as_tuple(self, properties_to_include: list[str]) -> tuple:
|
||||
out = dataclasses.astuple(self)
|
||||
if not properties_to_include:
|
||||
return out
|
||||
else:
|
||||
return (
|
||||
*out,
|
||||
*(getattr(self, name) for name in properties_to_include),
|
||||
)
|
||||
|
||||
def copy(self: T, *, deep=True) -> T:
|
||||
return (copy.deepcopy if deep else copy.copy)(self)
|
||||
|
||||
class H5Dataclass(Dataclass):
|
||||
# settable with class params:
|
||||
_prefix : str = dataclasses.field(init=False, repr=False, default="")
|
||||
_n_pages : int = dataclasses.field(init=False, repr=False, default=10)
|
||||
_require_all : bool = dataclasses.field(init=False, repr=False, default=False)
|
||||
|
||||
def __init_subclass__(cls,
|
||||
prefix : typing.Optional[str] = None,
|
||||
n_pages : typing.Optional[int] = None,
|
||||
require_all : typing.Optional[bool] = None,
|
||||
**kw,
|
||||
):
|
||||
super().__init_subclass__(**kw)
|
||||
assert dataclasses.is_dataclass(cls)
|
||||
if prefix is not None: cls._prefix = prefix
|
||||
if n_pages is not None: cls._n_pages = n_pages
|
||||
if require_all is not None: cls._require_all = require_all
|
||||
|
||||
@classmethod
|
||||
def _get_fields(cls) -> typing.Iterable[DataclassField]:
|
||||
for field in dataclasses.fields(cls):
|
||||
if not field.init:
|
||||
continue
|
||||
assert field.name not in ("_prefix", "_n_pages", "_require_all"), (
|
||||
f"{field.name!r} can not be in {cls.__qualname__}.__init__.\n"
|
||||
"Set it with dataclasses.field(default=YOUR_VALUE, init=False, repr=False)"
|
||||
)
|
||||
if isinstance(field.type, str):
|
||||
raise TypeError("Type hints are strings, perhaps avoid using `from __future__ import annotations`")
|
||||
|
||||
type_inner = strip_optional(field.type)
|
||||
is_prefix = typing.get_origin(type_inner) is dict and typing.get_args(type_inner)[:1] == (str,)
|
||||
field_type = typing.get_args(type_inner)[1] if is_prefix else field.type
|
||||
if field.default is None or typing.get_origin(field.type) is typing.Union and NoneType in typing.get_args(field.type):
|
||||
field_type = typing.Optional[field_type]
|
||||
|
||||
yield DataclassField(
|
||||
name = field.name,
|
||||
type = strip_optional(field_type),
|
||||
is_optional = typing.get_origin(field_type) is typing.Union and NoneType in typing.get_args(field_type),
|
||||
is_array = is_array(field_type),
|
||||
is_sliceable = is_array(field_type) and strip_optional(field_type) is not H5ArrayNoSlice,
|
||||
is_prefix = is_prefix,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_h5_file(cls : type[T],
|
||||
fname : typing.Union[PathLike, str],
|
||||
*,
|
||||
page : typing.Optional[int] = None,
|
||||
n_pages : typing.Optional[int] = None,
|
||||
read_slice : slice = slice(None),
|
||||
require_even_pages : bool = True,
|
||||
) -> T:
|
||||
if not isinstance(fname, Path):
|
||||
fname = Path(fname)
|
||||
if n_pages is None:
|
||||
n_pages = cls._n_pages
|
||||
if not fname.exists():
|
||||
raise FileNotFoundError(str(fname))
|
||||
if not h5.is_hdf5(fname):
|
||||
raise TypeError(f"Not a HDF5 file: {str(fname)!r}")
|
||||
|
||||
# if this class has no fields, print a example class:
|
||||
if not any(field.init for field in dataclasses.fields(cls)):
|
||||
with h5.File(fname, "r") as f:
|
||||
klen = max(map(len, f.keys()))
|
||||
example_cls = f"\nclass {cls.__name__}(Dataclass, require_all=True):\n" + "\n".join(
|
||||
f" {k.ljust(klen)} : "
|
||||
+ (
|
||||
"H5Array" if prod(v.shape, 1) > 1 else (
|
||||
"float" if issubclass(v.dtype.type, np.floating) else (
|
||||
"int" if issubclass(v.dtype.type, np.integer) else (
|
||||
"bool" if issubclass(v.dtype.type, np.bool_) else (
|
||||
"typing.Any"
|
||||
))))).ljust(14 + 1)
|
||||
+ f" #{repr(v).split(':', 1)[1].removesuffix('>')}"
|
||||
for k, v in f.items()
|
||||
)
|
||||
raise NotImplementedError(f"{cls!r} has no fields!\nPerhaps try the following:{example_cls}")
|
||||
|
||||
fields_consumed = set()
|
||||
|
||||
def make_kwarg(
|
||||
file : h5.File,
|
||||
keys : typing.KeysView,
|
||||
field : DataclassField,
|
||||
) -> tuple[str, typing.Any]:
|
||||
if field.is_optional:
|
||||
if field.name not in keys:
|
||||
return field.name, None
|
||||
if field.is_sliceable:
|
||||
if page is not None:
|
||||
n_items = int(f[cls._prefix + field.name].shape[0])
|
||||
page_len = n_items // n_pages
|
||||
modulus = n_items % n_pages
|
||||
if modulus: page_len += 1 # round up
|
||||
if require_even_pages and modulus:
|
||||
raise ValueError(f"Field {field.name!r} {tuple(f[cls._prefix + field.name].shape)} is not cleanly divisible into {n_pages} pages")
|
||||
this_slice = slice(
|
||||
start = page_len * page,
|
||||
stop = page_len * (page+1),
|
||||
step = read_slice.step, # inherit step
|
||||
)
|
||||
else:
|
||||
this_slice = read_slice
|
||||
else:
|
||||
this_slice = slice(None) # read all
|
||||
|
||||
# array or scalar?
|
||||
def read_dataset(var):
|
||||
# https://docs.h5py.org/en/stable/high/dataset.html#reading-writing-data
|
||||
if field.is_array:
|
||||
return var[this_slice]
|
||||
if var.shape == (1,):
|
||||
return var[0]
|
||||
else:
|
||||
return var[()]
|
||||
|
||||
if field.is_prefix:
|
||||
fields_consumed.update(
|
||||
key
|
||||
for key in keys if key.startswith(f"{cls._prefix}{field.name}_")
|
||||
)
|
||||
return field.name, {
|
||||
key.removeprefix(f"{cls._prefix}{field.name}_") : read_dataset(file[key])
|
||||
for key in keys if key.startswith(f"{cls._prefix}{field.name}_")
|
||||
}
|
||||
else:
|
||||
fields_consumed.add(cls._prefix + field.name)
|
||||
return field.name, read_dataset(file[cls._prefix + field.name])
|
||||
|
||||
with h5.File(fname, "r") as f:
|
||||
keys = f.keys()
|
||||
init_dict = dict( make_kwarg(f, keys, i) for i in cls._get_fields() )
|
||||
|
||||
try:
|
||||
out = cls(**init_dict)
|
||||
except Exception as e:
|
||||
class_attrs = set(field.name for field in dataclasses.fields(cls))
|
||||
file_attr = set(init_dict.keys())
|
||||
raise e.__class__(f"{e}. {class_attrs=}, {file_attr=}, diff={class_attrs.symmetric_difference(file_attr)}") from e
|
||||
|
||||
if cls._require_all:
|
||||
fields_not_consumed = set(keys) - fields_consumed
|
||||
if fields_not_consumed:
|
||||
raise ValueError(f"Not all HDF5 fields consumed: {fields_not_consumed!r}")
|
||||
|
||||
return out
|
||||
|
||||
def to_h5_file(self,
|
||||
fname : PathLike,
|
||||
mkdir : bool = False,
|
||||
):
|
||||
if not isinstance(fname, Path):
|
||||
fname = Path(fname)
|
||||
if not fname.parent.is_dir():
|
||||
if mkdir:
|
||||
fname.parent.mkdir(parents=True)
|
||||
else:
|
||||
raise NotADirectoryError(fname.parent)
|
||||
|
||||
with h5.File(fname, "w") as f:
|
||||
for field in type(self)._get_fields():
|
||||
if field.is_optional and getattr(self, field.name) is None:
|
||||
continue
|
||||
value = getattr(self, field.name)
|
||||
if field.is_array:
|
||||
if any(type(i) is not np.ndarray for i in (value.values() if field.is_prefix else [value])):
|
||||
raise TypeError(
|
||||
"When dumping a H5Dataclass, make sure the array fields are "
|
||||
f"numpy arrays (the type of {field.name!r} is {typing._type_repr(type(value))}).\n"
|
||||
"Example: h5dataclass.map_arrays(torch.Tensor.numpy)"
|
||||
)
|
||||
else:
|
||||
pass
|
||||
|
||||
def write_value(key: str, value: typing.Any):
|
||||
if field.is_array:
|
||||
f.create_dataset(key, data=value, **hdf5plugin.LZ4())
|
||||
else:
|
||||
f.create_dataset(key, data=value)
|
||||
|
||||
if field.is_prefix:
|
||||
for k, v in value.items():
|
||||
write_value(self._prefix + field.name + "_" + k, v)
|
||||
else:
|
||||
write_value(self._prefix + field.name, value)
|
||||
|
||||
def map_arrays(self: T, func: typing.Callable[[H5Array], H5Array], do_copy: bool = False) -> T:
|
||||
if do_copy: # shallow
|
||||
self = self.copy(deep=False)
|
||||
for field in type(self)._get_fields():
|
||||
if field.is_optional and getattr(self, field.name) is None:
|
||||
continue
|
||||
if field.is_prefix and field.is_array:
|
||||
setattr(self, field.name, {
|
||||
k : func(v)
|
||||
for k, v in getattr(self, field.name).items()
|
||||
})
|
||||
elif field.is_array:
|
||||
setattr(self, field.name, func(getattr(self, field.name)))
|
||||
|
||||
return self
|
||||
|
||||
def astype(self: T, t: type, do_copy: bool = False, convert_nonfloats: bool = False) -> T:
|
||||
return self.map_arrays(lambda x: x.astype(t) if convert_nonfloats or not np.issubdtype(x.dtype, int) else x)
|
||||
|
||||
def copy(self: T, *, deep=True) -> T:
|
||||
out = super().copy(deep=deep)
|
||||
if not deep:
|
||||
for field in type(self)._get_fields():
|
||||
if field.is_prefix:
|
||||
out[field.name] = copy.copy(field.name)
|
||||
return out
|
||||
|
||||
@property
|
||||
def shape(self) -> dict[str, tuple[int, ...]]:
|
||||
return {
|
||||
key: value.shape
|
||||
for key, value in self.items()
|
||||
if hasattr(value, "shape")
|
||||
}
|
||||
|
||||
class TransformableDataclassMixin(metaclass=DataclassABCMeta):
|
||||
|
||||
@abstractmethod
|
||||
def transform(self: T, mat4: np.ndarray, inplace=False) -> T:
|
||||
...
|
||||
|
||||
def transform_to(self: T, name: str, inverse_name: str = None, *, inplace=False) -> T:
|
||||
mtx = self.transforms[name]
|
||||
out = self.transform(mtx, inplace=inplace)
|
||||
out.transforms.pop(name) # consumed
|
||||
|
||||
inv = np.linalg.inv(mtx)
|
||||
for key in list(out.transforms.keys()): # maintain the other transforms
|
||||
out.transforms[key] = out.transforms[key] @ inv
|
||||
if inverse_name is not None: # store inverse
|
||||
out.transforms[inverse_name] = inv
|
||||
|
||||
return out
|
||||
48
ifield/data/common/mesh.py
Normal file
48
ifield/data/common/mesh.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from math import pi
|
||||
from trimesh import Trimesh
|
||||
import numpy as np
|
||||
import os
|
||||
import trimesh
|
||||
import trimesh.transformations as T
|
||||
|
||||
DEBUG = bool(os.environ.get("IFIELD_DEBUG", ""))
|
||||
|
||||
__doc__ = """
|
||||
Here are some helper functions for processing data.
|
||||
"""
|
||||
|
||||
def rotate_to_closest_axis_aligned_bounds(
|
||||
mesh : Trimesh,
|
||||
order_axes : bool = True,
|
||||
fail_ok : bool = True,
|
||||
) -> np.ndarray:
|
||||
to_origin_mat4, extents = trimesh.bounds.oriented_bounds(mesh, ordered=not order_axes)
|
||||
to_aabb_rot_mat4 = T.euler_matrix(*T.decompose_matrix(to_origin_mat4)[3])
|
||||
|
||||
if not order_axes:
|
||||
return to_aabb_rot_mat4
|
||||
|
||||
v = pi / 4 * 1.01 # tolerance
|
||||
v2 = pi / 2
|
||||
|
||||
faces = (
|
||||
(0, 0),
|
||||
(1, 0),
|
||||
(2, 0),
|
||||
(3, 0),
|
||||
(0, 1),
|
||||
(0,-1),
|
||||
)
|
||||
orientations = [ # 6 faces x 4 rotations per face
|
||||
(f[0] * v2, f[1] * v2, i * v2)
|
||||
for i in range(4)
|
||||
for f in faces]
|
||||
|
||||
for x, y, z in orientations:
|
||||
mat4 = T.euler_matrix(x, y, z) @ to_aabb_rot_mat4
|
||||
ai, aj, ak = T.euler_from_matrix(mat4)
|
||||
if abs(ai) <= v and abs(aj) <= v and abs(ak) <= v:
|
||||
return mat4
|
||||
|
||||
if fail_ok: return to_aabb_rot_mat4
|
||||
raise Exception("Unable to orient mesh")
|
||||
297
ifield/data/common/points.py
Normal file
297
ifield/data/common/points.py
Normal file
@@ -0,0 +1,297 @@
|
||||
from __future__ import annotations
|
||||
from ...utils.helpers import compose
|
||||
from functools import reduce, lru_cache
|
||||
from math import ceil
|
||||
from typing import Iterable
|
||||
import numpy as np
|
||||
import operator
|
||||
|
||||
__doc__ = """
|
||||
Here are some helper functions for processing data.
|
||||
"""
|
||||
|
||||
|
||||
def img2col(img: np.ndarray, psize: int) -> np.ndarray:
|
||||
# based of ycb_generate_point_cloud.py provided by YCB
|
||||
|
||||
n_channels = 1 if len(img.shape) == 2 else img.shape[0]
|
||||
n_channels, rows, cols = (1,) * (3 - len(img.shape)) + img.shape
|
||||
|
||||
# pad the image
|
||||
img_pad = np.zeros((
|
||||
n_channels,
|
||||
int(ceil(1.0 * rows / psize) * psize),
|
||||
int(ceil(1.0 * cols / psize) * psize),
|
||||
))
|
||||
img_pad[:, 0:rows, 0:cols] = img
|
||||
|
||||
# allocate output buffer
|
||||
final = np.zeros((
|
||||
img_pad.shape[1],
|
||||
img_pad.shape[2],
|
||||
n_channels,
|
||||
psize,
|
||||
psize,
|
||||
))
|
||||
|
||||
for c in range(n_channels):
|
||||
for x in range(psize):
|
||||
for y in range(psize):
|
||||
img_shift = np.vstack((
|
||||
img_pad[c, x:],
|
||||
img_pad[c, :x]))
|
||||
img_shift = np.column_stack((
|
||||
img_shift[:, y:],
|
||||
img_shift[:, :y]))
|
||||
final[x::psize, y::psize, c] = np.swapaxes(
|
||||
img_shift.reshape(
|
||||
int(img_pad.shape[1] / psize), psize,
|
||||
int(img_pad.shape[2] / psize), psize),
|
||||
1,
|
||||
2)
|
||||
|
||||
# crop output and unwrap axes with size==1
|
||||
return np.squeeze(final[
|
||||
0:rows - psize + 1,
|
||||
0:cols - psize + 1])
|
||||
|
||||
def filter_depth_discontinuities(depth_map: np.ndarray, filt_size = 7, thresh = 1000) -> np.ndarray:
|
||||
"""
|
||||
Removes data close to discontinuities, with size filt_size.
|
||||
"""
|
||||
# based of ycb_generate_point_cloud.py provided by YCB
|
||||
|
||||
# Ensure that filter sizes are okay
|
||||
assert filt_size % 2, "Can only use odd filter sizes."
|
||||
|
||||
# Compute discontinuities
|
||||
offset = int(filt_size - 1) // 2
|
||||
patches = 1.0 * img2col(depth_map, filt_size)
|
||||
mids = patches[:, :, offset, offset]
|
||||
mins = np.min(patches, axis=(2, 3))
|
||||
maxes = np.max(patches, axis=(2, 3))
|
||||
|
||||
discont = np.maximum(
|
||||
np.abs(mins - mids),
|
||||
np.abs(maxes - mids))
|
||||
mark = discont > thresh
|
||||
|
||||
# Account for offsets
|
||||
final_mark = np.zeros(depth_map.shape, dtype=np.uint16)
|
||||
final_mark[offset:offset + mark.shape[0],
|
||||
offset:offset + mark.shape[1]] = mark
|
||||
|
||||
return depth_map * (1 - final_mark)
|
||||
|
||||
def reorient_depth_map(
|
||||
depth_map : np.ndarray,
|
||||
rgb_map : np.ndarray,
|
||||
depth_mat3 : np.ndarray, # 3x3 intrinsic camera matrix
|
||||
depth_vec5 : np.ndarray, # 5 distortion parameters (k1, k2, p1, p2, k3)
|
||||
rgb_mat3 : np.ndarray, # 3x3 intrinsic camera matrix
|
||||
rgb_vec5 : np.ndarray, # 5 distortion parameters (k1, k2, p1, p2, k3)
|
||||
ir_to_rgb_mat4 : np.ndarray, # extrinsic transformation matrix from depth to rgb camera viewpoint
|
||||
rgb_mask_map : np.ndarray = None,
|
||||
_output_points = False, # retval (H, W) if false else (N, XYZRGB)
|
||||
_output_hits_uvs = False, # retval[1] is dtype=bool of hits shaped like depth_map
|
||||
) -> np.ndarray:
|
||||
|
||||
"""
|
||||
Corrects depth_map to be from the same view as the rgb_map, with the same dimensions.
|
||||
If _output_points is True, the points returned are in the rgb camera space.
|
||||
"""
|
||||
# based of ycb_generate_point_cloud.py provided by YCB
|
||||
# now faster AND more easy on the GIL
|
||||
|
||||
height_old, width_old, *_ = depth_map.shape
|
||||
height, width, *_ = rgb_map.shape
|
||||
|
||||
|
||||
d_cx, r_cx = depth_mat3[0, 2], rgb_mat3[0, 2] # optical center
|
||||
d_cy, r_cy = depth_mat3[1, 2], rgb_mat3[1, 2]
|
||||
d_fx, r_fx = depth_mat3[0, 0], rgb_mat3[0, 0] # focal length
|
||||
d_fy, r_fy = depth_mat3[1, 1], rgb_mat3[1, 1]
|
||||
d_k1, d_k2, d_p1, d_p2, d_k3 = depth_vec5
|
||||
c_k1, c_k2, c_p1, c_p2, c_k3 = rgb_vec5
|
||||
|
||||
# make a UV grid over depth_map
|
||||
u, v = np.meshgrid(
|
||||
np.arange(width_old),
|
||||
np.arange(height_old),
|
||||
)
|
||||
|
||||
# compute xyz coordinates for all depths
|
||||
xyz_depth = np.stack((
|
||||
(u - d_cx) / d_fx,
|
||||
(v - d_cy) / d_fy,
|
||||
depth_map,
|
||||
np.ones(depth_map.shape)
|
||||
)).reshape((4, -1))
|
||||
xyz_depth = xyz_depth[:, xyz_depth[2] != 0]
|
||||
|
||||
# undistort depth coordinates
|
||||
d_x, d_y = xyz_depth[:2]
|
||||
r = np.linalg.norm(xyz_depth[:2], axis=0)
|
||||
xyz_depth[0, :] \
|
||||
= d_x / (1 + d_k1*r**2 + d_k2*r**4 + d_k3*r**6) \
|
||||
- (2*d_p1*d_x*d_y + d_p2*(r**2 + 2*d_x**2))
|
||||
xyz_depth[1, :] \
|
||||
= d_y / (1 + d_k1*r**2 + d_k2*r**4 + d_k3*r**6) \
|
||||
- (d_p1*(r**2 + 2*d_y**2) + 2*d_p2*d_x*d_y)
|
||||
|
||||
# unproject x and y
|
||||
xyz_depth[0, :] *= xyz_depth[2, :]
|
||||
xyz_depth[1, :] *= xyz_depth[2, :]
|
||||
|
||||
# convert depths to RGB camera viewpoint
|
||||
xyz_rgb = ir_to_rgb_mat4 @ xyz_depth
|
||||
|
||||
# project depths to RGB canvas
|
||||
rgb_z_inv = 1 / xyz_rgb[2] # perspective correction
|
||||
rgb_uv = np.stack((
|
||||
xyz_rgb[0] * rgb_z_inv * r_fx + r_cx + 0.5,
|
||||
xyz_rgb[1] * rgb_z_inv * r_fy + r_cy + 0.5,
|
||||
)).astype(np.int)
|
||||
|
||||
# mask of the rgb_xyz values within view of rgb_map
|
||||
mask = reduce(operator.and_, [
|
||||
rgb_uv[0] >= 0,
|
||||
rgb_uv[1] >= 0,
|
||||
rgb_uv[0] < width,
|
||||
rgb_uv[1] < height,
|
||||
])
|
||||
if rgb_mask_map is not None:
|
||||
mask[mask] &= rgb_mask_map[
|
||||
rgb_uv[1, mask],
|
||||
rgb_uv[0, mask]]
|
||||
|
||||
if not _output_points: # output image
|
||||
output = np.zeros((height, width), dtype=depth_map.dtype)
|
||||
output[
|
||||
rgb_uv[1, mask],
|
||||
rgb_uv[0, mask],
|
||||
] = xyz_rgb[2, mask]
|
||||
|
||||
else: # output pointcloud
|
||||
rgbs = rgb_map[ # lookup rgb values using rgb_uv
|
||||
rgb_uv[1, mask],
|
||||
rgb_uv[0, mask]]
|
||||
output = np.stack((
|
||||
xyz_rgb[0, mask], # x
|
||||
xyz_rgb[1, mask], # y
|
||||
xyz_rgb[2, mask], # z
|
||||
rgbs[:, 0], # r
|
||||
rgbs[:, 1], # g
|
||||
rgbs[:, 2], # b
|
||||
)).T
|
||||
|
||||
# output for realsies
|
||||
if not _output_hits_uvs: #raw
|
||||
return output
|
||||
else: # with hit mask
|
||||
uv = np.zeros((height, width), dtype=bool)
|
||||
# filter points overlapping in the depth map
|
||||
uv_indices = (
|
||||
rgb_uv[1, mask],
|
||||
rgb_uv[0, mask],
|
||||
)
|
||||
_, chosen = np.unique( uv_indices[0] << 32 | uv_indices[1], return_index=True )
|
||||
output = output[chosen, :]
|
||||
uv[uv_indices] = True
|
||||
return output, uv
|
||||
|
||||
def join_rgb_and_depth_to_points(*a, **kw) -> np.ndarray:
|
||||
return reorient_depth_map(*a, _output_points=True, **kw)
|
||||
|
||||
@compose(np.array) # block lru cache mutation
|
||||
@lru_cache(maxsize=1)
|
||||
@compose(list)
|
||||
def generate_equidistant_sphere_points(
|
||||
n : int,
|
||||
centroid : np.ndarray = (0, 0, 0),
|
||||
radius : float = 1,
|
||||
compute_sphere_coordinates : bool = False,
|
||||
compute_normals : bool = False,
|
||||
shift_theta : bool = False,
|
||||
) -> Iterable[tuple[float, ...]]:
|
||||
# Deserno M. How to generate equidistributed points on the surface of a sphere
|
||||
# https://www.cmu.edu/biolphys/deserno/pdf/sphere_equi.pdf
|
||||
|
||||
if compute_sphere_coordinates and compute_normals:
|
||||
raise ValueError(
|
||||
"'compute_sphere_coordinates' and 'compute_normals' are mutually exclusive"
|
||||
)
|
||||
|
||||
n_count = 0
|
||||
a = 4 * np.pi / n
|
||||
d = np.sqrt(a)
|
||||
n_theta = round(np.pi / d)
|
||||
d_theta = np.pi / n_theta
|
||||
d_phi = a / d_theta
|
||||
|
||||
for i in range(0, n_theta):
|
||||
theta = np.pi * (i + 0.5) / n_theta
|
||||
n_phi = round(2 * np.pi * np.sin(theta) / d_phi)
|
||||
|
||||
for j in range(0, n_phi):
|
||||
phi = 2 * np.pi * j / n_phi
|
||||
|
||||
if compute_sphere_coordinates: # (theta, phi)
|
||||
yield (
|
||||
theta if shift_theta else theta - 0.5*np.pi,
|
||||
phi,
|
||||
)
|
||||
elif compute_normals: # (x, y, z, nx, ny, nz)
|
||||
yield (
|
||||
centroid[0] + radius * np.sin(theta) * np.cos(phi),
|
||||
centroid[1] + radius * np.sin(theta) * np.sin(phi),
|
||||
centroid[2] + radius * np.cos(theta),
|
||||
np.sin(theta) * np.cos(phi),
|
||||
np.sin(theta) * np.sin(phi),
|
||||
np.cos(theta),
|
||||
)
|
||||
else: # (x, y, z)
|
||||
yield (
|
||||
centroid[0] + radius * np.sin(theta) * np.cos(phi),
|
||||
centroid[1] + radius * np.sin(theta) * np.sin(phi),
|
||||
centroid[2] + radius * np.cos(theta),
|
||||
)
|
||||
n_count += 1
|
||||
|
||||
|
||||
def generate_random_sphere_points(
|
||||
n : int,
|
||||
centroid : np.ndarray = (0, 0, 0),
|
||||
radius : float = 1,
|
||||
compute_sphere_coordinates : bool = False,
|
||||
compute_normals : bool = False,
|
||||
shift_theta : bool = False, # depends on convention
|
||||
) -> np.ndarray:
|
||||
if compute_sphere_coordinates and compute_normals:
|
||||
raise ValueError(
|
||||
"'compute_sphere_coordinates' and 'compute_normals' are mutually exclusive"
|
||||
)
|
||||
|
||||
theta = np.arcsin(np.random.uniform(-1, 1, n)) # inverse transform sampling
|
||||
phi = np.random.uniform(0, 2*np.pi, n)
|
||||
|
||||
if compute_sphere_coordinates: # (theta, phi)
|
||||
return np.stack((
|
||||
theta if not shift_theta else 0.5*np.pi + theta,
|
||||
phi,
|
||||
), axis=1)
|
||||
elif compute_normals: # (x, y, z, nx, ny, nz)
|
||||
return np.stack((
|
||||
centroid[0] + radius * np.cos(theta) * np.cos(phi),
|
||||
centroid[1] + radius * np.cos(theta) * np.sin(phi),
|
||||
centroid[2] + radius * np.sin(theta),
|
||||
np.cos(theta) * np.cos(phi),
|
||||
np.cos(theta) * np.sin(phi),
|
||||
np.sin(theta),
|
||||
), axis=1)
|
||||
else: # (x, y, z)
|
||||
return np.stack((
|
||||
centroid[0] + radius * np.cos(theta) * np.cos(phi),
|
||||
centroid[1] + radius * np.cos(theta) * np.sin(phi),
|
||||
centroid[2] + radius * np.sin(theta),
|
||||
), axis=1)
|
||||
85
ifield/data/common/processing.py
Normal file
85
ifield/data/common/processing.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from .h5_dataclasses import H5Dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Hashable, Optional, Callable
|
||||
import os
|
||||
|
||||
DEBUG = bool(os.environ.get("IFIELD_DEBUG", ""))
|
||||
|
||||
__doc__ = """
|
||||
Here are some helper functions for processing data.
|
||||
"""
|
||||
|
||||
# multiprocessing does not work due to my rediculous use of closures, which seemingly cannot be pickled
|
||||
# paralelize it in the shell instead
|
||||
|
||||
def precompute_data(
|
||||
computer : Callable[[Hashable], Optional[H5Dataclass]],
|
||||
identifiers : list[Hashable],
|
||||
output_paths : list[Path],
|
||||
page : tuple[int, int] = (0, 1),
|
||||
*,
|
||||
force : bool = False,
|
||||
debug : bool = False,
|
||||
):
|
||||
"""
|
||||
precomputes data and stores them as HDF5 datasets using `.to_file(path: Path)`
|
||||
"""
|
||||
|
||||
page, n_pages = page
|
||||
assert len(identifiers) == len(output_paths)
|
||||
|
||||
total = len(identifiers)
|
||||
identifier_max_len = max(map(len, map(str, identifiers)))
|
||||
t_epoch = None
|
||||
def log(state: str, is_start = False):
|
||||
nonlocal t_epoch
|
||||
if is_start: t_epoch = datetime.now()
|
||||
td = timedelta(0) if is_start else datetime.now() - t_epoch
|
||||
print(" - "
|
||||
f"{str(index+1).rjust(len(str(total)))}/{total}: "
|
||||
f"{str(identifier).ljust(identifier_max_len)} @ {td}: {state}"
|
||||
)
|
||||
|
||||
print(f"precompute_data(computer={computer.__module__}.{computer.__qualname__}, identifiers=..., force={force}, page={page})")
|
||||
t_begin = datetime.now()
|
||||
failed = []
|
||||
|
||||
# pagination
|
||||
page_size = total // n_pages + bool(total % n_pages)
|
||||
jobs = list(zip(identifiers, output_paths))[page_size*page : page_size*(page+1)]
|
||||
|
||||
for index, (identifier, output_path) in enumerate(jobs, start=page_size*page):
|
||||
if not force and output_path.exists() and output_path.stat().st_size > 0:
|
||||
continue
|
||||
|
||||
log("compute", is_start=True)
|
||||
|
||||
# compute
|
||||
try:
|
||||
res = computer(identifier)
|
||||
except Exception as e:
|
||||
failed.append(identifier)
|
||||
log(f"failed compute: {e.__class__.__name__}: {e}")
|
||||
if DEBUG or debug: raise e
|
||||
continue
|
||||
if res is None:
|
||||
failed.append(identifier)
|
||||
log("no result")
|
||||
continue
|
||||
|
||||
# write to file
|
||||
try:
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
res.to_h5_file(output_path)
|
||||
except Exception as e:
|
||||
failed.append(identifier)
|
||||
log(f"failed write: {e.__class__.__name__}: {e}")
|
||||
if output_path.is_file(): output_path.unlink() # cleanup
|
||||
if DEBUG or debug: raise e
|
||||
continue
|
||||
|
||||
log("done")
|
||||
|
||||
print("precompute_data finished in", datetime.now() - t_begin)
|
||||
print("failed:", failed or None)
|
||||
768
ifield/data/common/scan.py
Normal file
768
ifield/data/common/scan.py
Normal file
@@ -0,0 +1,768 @@
|
||||
from ...utils.helpers import compose
|
||||
from . import points
|
||||
from .h5_dataclasses import H5Dataclass, H5Array, H5ArrayNoSlice, TransformableDataclassMixin
|
||||
from methodtools import lru_cache
|
||||
from sklearn.neighbors import BallTree
|
||||
import faiss
|
||||
from trimesh import Trimesh
|
||||
from typing import Iterable
|
||||
from typing import Optional, TypeVar
|
||||
import mesh_to_sdf
|
||||
import mesh_to_sdf.scan as sdf_scan
|
||||
import numpy as np
|
||||
import trimesh
|
||||
import trimesh.transformations as T
|
||||
import warnings
|
||||
|
||||
__doc__ = """
|
||||
Here are some helper types for data.
|
||||
"""
|
||||
|
||||
_T = TypeVar("T")
|
||||
|
||||
class InvalidateLRUOnWriteMixin:
|
||||
def __setattr__(self, key, value):
|
||||
if not key.startswith("__wire|"):
|
||||
for attr in dir(self):
|
||||
if attr.startswith("__wire|"):
|
||||
getattr(self, attr).cache_clear()
|
||||
return super().__setattr__(key, value)
|
||||
def lru_property(func):
|
||||
return lru_cache(maxsize=1)(property(func))
|
||||
|
||||
class SingleViewScan(H5Dataclass, TransformableDataclassMixin, InvalidateLRUOnWriteMixin, require_all=True):
|
||||
points_hit : H5ArrayNoSlice # (N, 3)
|
||||
normals_hit : Optional[H5ArrayNoSlice] # (N, 3)
|
||||
points_miss : H5ArrayNoSlice # (M, 3)
|
||||
distances_miss : Optional[H5ArrayNoSlice] # (M)
|
||||
colors_hit : Optional[H5ArrayNoSlice] # (N, 3)
|
||||
colors_miss : Optional[H5ArrayNoSlice] # (M, 3)
|
||||
uv_hits : Optional[H5ArrayNoSlice] # (H, W) dtype=bool
|
||||
uv_miss : Optional[H5ArrayNoSlice] # (H, W) dtype=bool (the reason we store both is due to missing data depth sensor data or filtered backfaces)
|
||||
cam_pos : H5ArrayNoSlice # (3)
|
||||
cam_mat4 : Optional[H5ArrayNoSlice] # (4, 4)
|
||||
proj_mat4 : Optional[H5ArrayNoSlice] # (4, 4)
|
||||
transforms : dict[str, H5ArrayNoSlice] # a map of 4x4 transformation matrices
|
||||
|
||||
def transform(self: _T, mat4: np.ndarray, inplace=False) -> _T:
|
||||
scale_xyz = mat4[:3, :3].sum(axis=0) # https://math.stackexchange.com/a/1463487
|
||||
assert all(scale_xyz - scale_xyz[0] < 1e-8), f"differenty scaled axes: {scale_xyz}"
|
||||
|
||||
out = self if inplace else self.copy(deep=False)
|
||||
out.points_hit = T.transform_points(self.points_hit, mat4)
|
||||
out.normals_hit = T.transform_points(self.normals_hit, mat4) if self.normals_hit is not None else None
|
||||
out.points_miss = T.transform_points(self.points_miss, mat4)
|
||||
out.distances_miss = self.distances_miss * scale_xyz
|
||||
out.cam_pos = T.transform_points(self.points_cam, mat4)[-1]
|
||||
out.cam_mat4 = (mat4 @ self.cam_mat4) if self.cam_mat4 is not None else None
|
||||
out.proj_mat4 = (mat4 @ self.proj_mat4) if self.proj_mat4 is not None else None
|
||||
return out
|
||||
|
||||
def compute_miss_distances(self: _T, *, copy: bool = False, deep: bool = False) -> _T:
|
||||
assert not self.has_miss_distances
|
||||
if not self.is_hitting:
|
||||
raise ValueError("No hits to compute the ray distance towards")
|
||||
|
||||
out = self.copy(deep=deep) if copy else self
|
||||
out.distances_miss \
|
||||
= distance_from_rays_to_point_cloud(
|
||||
ray_origins = out.points_cam,
|
||||
ray_dirs = out.ray_dirs_miss,
|
||||
points = out.points_hit,
|
||||
).astype(out.points_cam.dtype)
|
||||
|
||||
return out
|
||||
|
||||
@lru_property
|
||||
def points(self) -> np.ndarray: # (N+M+1, 3)
|
||||
return np.concatenate((
|
||||
self.points_hit,
|
||||
self.points_miss,
|
||||
self.points_cam,
|
||||
))
|
||||
|
||||
@lru_property
|
||||
def uv_points(self) -> np.ndarray: # (N+M+1, 3)
|
||||
if not self.has_uv: raise ValueError
|
||||
out = np.full((*self.uv_hits.shape, 3), np.nan, dtype=self.points_hit.dtype)
|
||||
out[self.uv_hits, :] = self.points_hit
|
||||
out[self.uv_miss, :] = self.points_miss
|
||||
return out
|
||||
|
||||
@lru_property
|
||||
def uv_normals(self) -> np.ndarray: # (N+M+1, 3)
|
||||
if not self.has_uv: raise ValueError
|
||||
out = np.full((*self.uv_hits.shape, 3), np.nan, dtype=self.normals_hit.dtype)
|
||||
out[self.uv_hits, :] = self.normals_hit
|
||||
return out
|
||||
|
||||
@lru_property
|
||||
def points_cam(self) -> Optional[np.ndarray]: # (1, 3)
|
||||
if self.cam_pos is None: return None
|
||||
return self.cam_pos[None, :]
|
||||
|
||||
@lru_property
|
||||
def points_hit_centroid(self) -> np.ndarray:
|
||||
return self.points_hit.mean(axis=0)
|
||||
|
||||
@lru_property
|
||||
def points_hit_std(self) -> np.ndarray:
|
||||
return self.points_hit.std(axis=0)
|
||||
|
||||
@lru_property
|
||||
def is_hitting(self) -> bool:
|
||||
return len(self.points_hit) > 0
|
||||
|
||||
@lru_property
|
||||
def is_empty(self) -> bool:
|
||||
return not (len(self.points_hit) or len(self.points_miss))
|
||||
|
||||
@lru_property
|
||||
def has_colors(self) -> bool:
|
||||
return self.colors_hit is not None or self.colors_miss is not None
|
||||
|
||||
@lru_property
|
||||
def has_normals(self) -> bool:
|
||||
return self.normals_hit is not None
|
||||
|
||||
@lru_property
|
||||
def has_uv(self) -> bool:
|
||||
return self.uv_hits is not None
|
||||
|
||||
@lru_property
|
||||
def has_miss_distances(self) -> bool:
|
||||
return self.distances_miss is not None
|
||||
|
||||
@lru_property
|
||||
def xyzrgb_hit(self) -> np.ndarray: # (N, 6)
|
||||
if self.colors_hit is None: raise ValueError
|
||||
return np.concatenate([self.points_hit, self.colors_hit], axis=1)
|
||||
|
||||
@lru_property
|
||||
def xyzrgb_miss(self) -> np.ndarray: # (M, 6)
|
||||
if self.colors_miss is None: raise ValueError
|
||||
return np.concatenate([self.points_miss, self.colors_miss], axis=1)
|
||||
|
||||
@lru_property
|
||||
def ray_dirs_hit(self) -> np.ndarray: # (N, 3)
|
||||
out = self.points_hit - self.points_cam
|
||||
out /= np.linalg.norm(out, axis=-1)[:, None] # normalize
|
||||
return out
|
||||
|
||||
@lru_property
|
||||
def ray_dirs_miss(self) -> np.ndarray: # (N, 3)
|
||||
out = self.points_miss - self.points_cam
|
||||
out /= np.linalg.norm(out, axis=-1)[:, None] # normalize
|
||||
return out
|
||||
|
||||
@classmethod
|
||||
def from_mesh_single_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewScan":
|
||||
if "phi" not in kw and not "theta" in kw:
|
||||
kw["theta"], kw["phi"] = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
|
||||
scan = sample_single_view_scan_from_mesh(mesh, **kw)
|
||||
if compute_miss_distances and scan.is_hitting:
|
||||
scan.compute_miss_distances()
|
||||
return scan
|
||||
|
||||
def to_uv_scan(self) -> "SingleViewUVScan":
|
||||
return SingleViewUVScan.from_scan(self)
|
||||
|
||||
@classmethod
|
||||
def from_uv_scan(self, uvscan: "SingleViewUVScan") -> "SingleViewUVScan":
|
||||
return uvscan.to_scan()
|
||||
|
||||
# The same, but with support for pagination (should have been this way since the start...)
|
||||
class SingleViewUVScan(H5Dataclass, TransformableDataclassMixin, InvalidateLRUOnWriteMixin, require_all=True):
|
||||
# B may be (N) or (H, W), the latter may be flattened
|
||||
hits : H5Array # (*B) dtype=bool
|
||||
miss : H5Array # (*B) dtype=bool (the reason we store both is due to missing data depth sensor data or filtered backface hits)
|
||||
points : H5Array # (*B, 3) on far plane if miss, NaN if neither hit or miss
|
||||
normals : Optional[H5Array] # (*B, 3) NaN if not hit
|
||||
colors : Optional[H5Array] # (*B, 3)
|
||||
distances : Optional[H5Array] # (*B) NaN if not miss
|
||||
cam_pos : Optional[H5ArrayNoSlice] # (3) or (*B, 3)
|
||||
cam_mat4 : Optional[H5ArrayNoSlice] # (4, 4)
|
||||
proj_mat4 : Optional[H5ArrayNoSlice] # (4, 4)
|
||||
transforms : dict[str, H5ArrayNoSlice] # a map of 4x4 transformation matrices
|
||||
|
||||
@classmethod
|
||||
def from_scan(cls, scan: SingleViewScan):
|
||||
if not scan.has_uv:
|
||||
raise ValueError("Scan cloud has no UV data")
|
||||
hits, miss = scan.uv_hits, scan.uv_miss
|
||||
dtype = scan.points_hit.dtype
|
||||
assert hits.ndim in (1, 2), hits.ndim
|
||||
assert hits.shape == miss.shape, (hits.shape, miss.shape)
|
||||
|
||||
points = np.full((*hits.shape, 3), np.nan, dtype=dtype)
|
||||
points[hits, :] = scan.points_hit
|
||||
points[miss, :] = scan.points_miss
|
||||
|
||||
normals = None
|
||||
if scan.has_normals:
|
||||
normals = np.full((*hits.shape, 3), np.nan, dtype=dtype)
|
||||
normals[hits, :] = scan.normals_hit
|
||||
|
||||
distances = None
|
||||
if scan.has_miss_distances:
|
||||
distances = np.full(hits.shape, np.nan, dtype=dtype)
|
||||
distances[miss] = scan.distances_miss
|
||||
|
||||
colors = None
|
||||
if scan.has_colors:
|
||||
colors = np.full((*hits.shape, 3), np.nan, dtype=dtype)
|
||||
if scan.colors_hit is not None:
|
||||
colors[hits, :] = scan.colors_hit
|
||||
if scan.colors_miss is not None:
|
||||
colors[miss, :] = scan.colors_miss
|
||||
|
||||
return cls(
|
||||
hits = hits,
|
||||
miss = miss,
|
||||
points = points,
|
||||
normals = normals,
|
||||
colors = colors,
|
||||
distances = distances,
|
||||
cam_pos = scan.cam_pos,
|
||||
cam_mat4 = scan.cam_mat4,
|
||||
proj_mat4 = scan.proj_mat4,
|
||||
transforms = scan.transforms,
|
||||
)
|
||||
|
||||
def to_scan(self) -> "SingleViewScan":
|
||||
if not self.is_single_view: raise ValueError
|
||||
return SingleViewScan(
|
||||
points_hit = self.points [self.hits, :],
|
||||
points_miss = self.points [self.miss, :],
|
||||
normals_hit = self.normals [self.hits, :] if self.has_normals else None,
|
||||
distances_miss = self.distances[self.miss] if self.has_miss_distances else None,
|
||||
colors_hit = self.colors [self.hits, :] if self.has_colors else None,
|
||||
colors_miss = self.colors [self.miss, :] if self.has_colors else None,
|
||||
uv_hits = self.hits,
|
||||
uv_miss = self.miss,
|
||||
cam_pos = self.cam_pos,
|
||||
cam_mat4 = self.cam_mat4,
|
||||
proj_mat4 = self.proj_mat4,
|
||||
transforms = self.transforms,
|
||||
)
|
||||
|
||||
def to_mesh(self) -> trimesh.Trimesh:
|
||||
faces: list[(tuple[int, int],)*3] = []
|
||||
for x in range(self.hits.shape[0]-1):
|
||||
for y in range(self.hits.shape[1]-1):
|
||||
c11 = x, y
|
||||
c12 = x, y+1
|
||||
c22 = x+1, y+1
|
||||
c21 = x+1, y
|
||||
|
||||
n = sum(map(self.hits.__getitem__, (c11, c12, c22, c21)))
|
||||
if n == 3:
|
||||
faces.append((*filter(self.hits.__getitem__, (c11, c12, c22, c21)),))
|
||||
elif n == 4:
|
||||
faces.append((c11, c12, c22))
|
||||
faces.append((c11, c22, c21))
|
||||
xy2idx = {c:i for i, c in enumerate(set(k for j in faces for k in j))}
|
||||
assert self.colors is not None
|
||||
return trimesh.Trimesh(
|
||||
vertices = [self.points[i] for i in xy2idx.keys()],
|
||||
vertex_colors = [self.colors[i] for i in xy2idx.keys()] if self.colors is not None else None,
|
||||
faces = [tuple(xy2idx[i] for i in face) for face in faces],
|
||||
)
|
||||
|
||||
def transform(self: _T, mat4: np.ndarray, inplace=False) -> _T:
|
||||
scale_xyz = mat4[:3, :3].sum(axis=0) # https://math.stackexchange.com/a/1463487
|
||||
assert all(scale_xyz - scale_xyz[0] < 1e-8), f"differenty scaled axes: {scale_xyz}"
|
||||
|
||||
unflat = self.hits.shape
|
||||
flat = np.product(unflat)
|
||||
|
||||
out = self if inplace else self.copy(deep=False)
|
||||
out.points = T.transform_points(self.points .reshape((*flat, 3)), mat4).reshape((*unflat, 3))
|
||||
out.normals = T.transform_points(self.normals.reshape((*flat, 3)), mat4).reshape((*unflat, 3)) if self.normals_hit is not None else None
|
||||
out.distances = self.distances_miss * scale_xyz
|
||||
out.cam_pos = T.transform_points(self.cam_pos[None, ...], mat4)[0]
|
||||
out.cam_mat4 = (mat4 @ self.cam_mat4) if self.cam_mat4 is not None else None
|
||||
out.proj_mat4 = (mat4 @ self.proj_mat4) if self.proj_mat4 is not None else None
|
||||
return out
|
||||
|
||||
def compute_miss_distances(self: _T, *, copy: bool = False, deep: bool = False, surface_points: Optional[np.ndarray] = None) -> _T:
|
||||
assert not self.has_miss_distances
|
||||
|
||||
shape = self.hits.shape
|
||||
|
||||
out = self.copy(deep=deep) if copy else self
|
||||
out.distances = np.zeros(shape, dtype=self.points.dtype)
|
||||
if self.is_hitting:
|
||||
out.distances[self.miss] \
|
||||
= distance_from_rays_to_point_cloud(
|
||||
ray_origins = self.cam_pos_unsqueezed_miss,
|
||||
ray_dirs = self.ray_dirs_miss,
|
||||
points = surface_points if surface_points is not None else self.points[self.hits],
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
def fill_missing_points(self: _T, *, copy: bool = False, deep: bool = False) -> _T:
|
||||
"""
|
||||
Fill in missing points as hitting the far plane.
|
||||
"""
|
||||
if not self.is_2d:
|
||||
raise ValueError("Cannot fill missing points for non-2d scan!")
|
||||
if not self.is_single_view:
|
||||
raise ValueError("Cannot fill missing points for non-single-view scans!")
|
||||
if self.cam_mat4 is None:
|
||||
raise ValueError("cam_mat4 is None")
|
||||
if self.proj_mat4 is None:
|
||||
raise ValueError("proj_mat4 is None")
|
||||
|
||||
uv = np.argwhere(self.missing).astype(self.points.dtype)
|
||||
uv[:, 0] /= (self.missing.shape[1] - 1) / 2
|
||||
uv[:, 1] /= (self.missing.shape[0] - 1) / 2
|
||||
uv -= 1
|
||||
uv = np.stack((
|
||||
uv[:, 1],
|
||||
-uv[:, 0],
|
||||
np.ones(uv.shape[0]), # far clipping plane
|
||||
np.ones(uv.shape[0]), # homogeneous coordinate
|
||||
), axis=-1)
|
||||
uv = uv @ (self.cam_mat4 @ np.linalg.inv(self.proj_mat4)).T
|
||||
|
||||
out = self.copy(deep=deep) if copy else self
|
||||
out.points[self.missing, :] = uv[:, :3] / uv[:, 3][:, None]
|
||||
return out
|
||||
|
||||
@lru_property
|
||||
def is_hitting(self) -> bool:
|
||||
return np.any(self.hits)
|
||||
|
||||
@lru_property
|
||||
def has_colors(self) -> bool:
|
||||
return not self.colors is None
|
||||
|
||||
@lru_property
|
||||
def has_normals(self) -> bool:
|
||||
return not self.normals is None
|
||||
|
||||
@lru_property
|
||||
def has_miss_distances(self) -> bool:
|
||||
return not self.distances is None
|
||||
|
||||
@lru_property
|
||||
def any_missing(self) -> bool:
|
||||
return np.any(self.missing)
|
||||
|
||||
@lru_property
|
||||
def has_missing(self) -> bool:
|
||||
return self.any_missing and not np.any(np.isnan(self.points[self.missing]))
|
||||
|
||||
@lru_property
|
||||
def cam_pos_unsqueezed(self) -> H5Array:
|
||||
if self.cam_pos.ndim != 1:
|
||||
return self.cam_pos
|
||||
else:
|
||||
cam_pos = self.cam_pos
|
||||
for _ in range(self.hits.ndim):
|
||||
cam_pos = cam_pos[None, ...]
|
||||
return cam_pos
|
||||
|
||||
@lru_property
|
||||
def cam_pos_unsqueezed_hit(self) -> H5Array:
|
||||
if self.cam_pos.ndim != 1:
|
||||
return self.cam_pos[self.hits, :]
|
||||
else:
|
||||
return self.cam_pos[None, :]
|
||||
|
||||
@lru_property
|
||||
def cam_pos_unsqueezed_miss(self) -> H5Array:
|
||||
if self.cam_pos.ndim != 1:
|
||||
return self.cam_pos[self.miss, :]
|
||||
else:
|
||||
return self.cam_pos[None, :]
|
||||
|
||||
@lru_property
|
||||
def ray_dirs(self) -> H5Array:
|
||||
return (self.points - self.cam_pos_unsqueezed) * (1 / self.depths[..., None])
|
||||
|
||||
@lru_property
|
||||
def ray_dirs_hit(self) -> H5Array:
|
||||
out = self.points[self.hits, :] - self.cam_pos_unsqueezed_hit
|
||||
out /= np.linalg.norm(out, axis=-1)[..., None] # normalize
|
||||
return out
|
||||
|
||||
@lru_property
|
||||
def ray_dirs_miss(self) -> H5Array:
|
||||
out = self.points[self.miss, :] - self.cam_pos_unsqueezed_miss
|
||||
out /= np.linalg.norm(out, axis=-1)[..., None] # normalize
|
||||
return out
|
||||
|
||||
@lru_property
|
||||
def depths(self) -> H5Array:
|
||||
return np.linalg.norm(self.points - self.cam_pos_unsqueezed, axis=-1)
|
||||
|
||||
@lru_property
|
||||
def missing(self) -> H5Array:
|
||||
return ~(self.hits | self.miss)
|
||||
|
||||
@classmethod
|
||||
def from_mesh_single_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewUVScan":
|
||||
if "phi" not in kw and not "theta" in kw:
|
||||
kw["theta"], kw["phi"] = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
|
||||
scan = sample_single_view_scan_from_mesh(mesh, **kw).to_uv_scan()
|
||||
if compute_miss_distances:
|
||||
scan.compute_miss_distances()
|
||||
assert scan.is_2d
|
||||
return scan
|
||||
|
||||
@classmethod
|
||||
def from_mesh_sphere_view(cls, mesh: Trimesh, *, compute_miss_distances: bool = False, **kw) -> "SingleViewUVScan":
|
||||
scan = sample_sphere_view_scan_from_mesh(mesh, **kw)
|
||||
if compute_miss_distances:
|
||||
surface_points = None
|
||||
if scan.hits.sum() > mesh.vertices.shape[0]:
|
||||
surface_points = mesh.vertices.astype(scan.points.dtype)
|
||||
if not kw.get("no_unit_sphere", False):
|
||||
translation, scale = compute_unit_sphere_transform(mesh, dtype=scan.points.dtype)
|
||||
surface_points = (surface_points + translation) * scale
|
||||
scan.compute_miss_distances(surface_points=surface_points)
|
||||
assert scan.is_flat
|
||||
return scan
|
||||
|
||||
def flatten_and_permute_(self: _T, copy=False) -> _T: # inplace by default
|
||||
n_items = np.product(self.hits.shape)
|
||||
permutation = np.random.permutation(n_items)
|
||||
|
||||
out = self.copy(deep=False) if copy else self
|
||||
out.hits = out.hits .reshape((n_items, ))[permutation]
|
||||
out.miss = out.miss .reshape((n_items, ))[permutation]
|
||||
out.points = out.points .reshape((n_items, 3))[permutation, :]
|
||||
out.normals = out.normals .reshape((n_items, 3))[permutation, :] if out.has_normals else None
|
||||
out.colors = out.colors .reshape((n_items, 3))[permutation, :] if out.has_colors else None
|
||||
out.distances = out.distances.reshape((n_items, ))[permutation] if out.has_miss_distances else None
|
||||
return out
|
||||
|
||||
@property
|
||||
def is_single_view(self) -> bool:
|
||||
return np.product(self.cam_pos.shape[:-1]) == 1 if not self.cam_pos is None else True
|
||||
|
||||
@property
|
||||
def is_flat(self) -> bool:
|
||||
return len(self.hits.shape) == 1
|
||||
|
||||
@property
|
||||
def is_2d(self) -> bool:
|
||||
return len(self.hits.shape) == 2
|
||||
|
||||
|
||||
# transforms can be found in pytorch3d.transforms and in open3d
|
||||
# and in trimesh.transformations
|
||||
|
||||
def sample_single_view_scans_from_mesh(
|
||||
mesh : Trimesh,
|
||||
*,
|
||||
n_batches : int,
|
||||
scan_resolution : int = 400,
|
||||
compute_normals : bool = False,
|
||||
fov : float = 1.0472, # 60 degrees in radians, vertical field of view.
|
||||
camera_distance : float = 2,
|
||||
no_filter_backhits : bool = False,
|
||||
) -> Iterable[SingleViewScan]:
|
||||
|
||||
normalized_mesh_cache = []
|
||||
|
||||
for _ in range(n_batches):
|
||||
theta, phi = points.generate_random_sphere_points(1, compute_sphere_coordinates=True)[0]
|
||||
|
||||
yield sample_single_view_scan_from_mesh(
|
||||
mesh = mesh,
|
||||
phi = phi,
|
||||
theta = theta,
|
||||
_mesh_is_normalized = False,
|
||||
scan_resolution = scan_resolution,
|
||||
compute_normals = compute_normals,
|
||||
fov = fov,
|
||||
camera_distance = camera_distance,
|
||||
no_filter_backhits = no_filter_backhits,
|
||||
_mesh_cache = normalized_mesh_cache,
|
||||
)
|
||||
|
||||
def sample_single_view_scan_from_mesh(
|
||||
mesh : Trimesh,
|
||||
*,
|
||||
phi : float,
|
||||
theta : float,
|
||||
scan_resolution : int = 200,
|
||||
compute_normals : bool = False,
|
||||
fov : float = 1.0472, # 60 degrees in radians, vertical field of view.
|
||||
camera_distance : float = 2,
|
||||
no_filter_backhits : bool = False,
|
||||
no_unit_sphere : bool = False,
|
||||
dtype : type = np.float32,
|
||||
_mesh_cache : Optional[list] = None, # provide a list if mesh is reused
|
||||
) -> SingleViewScan:
|
||||
|
||||
# scale and center to unit sphere
|
||||
is_cache = isinstance(_mesh_cache, list)
|
||||
if is_cache and _mesh_cache and _mesh_cache[0] is mesh:
|
||||
_, mesh, translation, scale = _mesh_cache
|
||||
else:
|
||||
if is_cache:
|
||||
if _mesh_cache:
|
||||
_mesh_cache.clear()
|
||||
_mesh_cache.append(mesh)
|
||||
translation, scale = compute_unit_sphere_transform(mesh)
|
||||
mesh = mesh_to_sdf.scale_to_unit_sphere(mesh)
|
||||
if is_cache:
|
||||
_mesh_cache.extend((mesh, translation, scale))
|
||||
|
||||
z_near = 1
|
||||
z_far = 3
|
||||
cam_mat4 = sdf_scan.get_camera_transform_looking_at_origin(phi, theta, camera_distance=camera_distance)
|
||||
cam_pos = cam_mat4 @ np.array([0, 0, 0, 1])
|
||||
|
||||
scan = sdf_scan.Scan(mesh,
|
||||
camera_transform = cam_mat4,
|
||||
resolution = scan_resolution,
|
||||
calculate_normals = compute_normals,
|
||||
fov = fov,
|
||||
z_near = z_near,
|
||||
z_far = z_far,
|
||||
no_flip_backfaced_normals = True
|
||||
)
|
||||
|
||||
# all the scan rays that hit the far plane, based on sdf_scan.Scan.__init__
|
||||
misses = np.argwhere(scan.depth_buffer == 0)
|
||||
points_miss = np.ones((misses.shape[0], 4))
|
||||
points_miss[:, [1, 0]] = misses.astype(float) / (scan_resolution -1) * 2 - 1
|
||||
points_miss[:, 1] *= -1
|
||||
points_miss[:, 2] = 1 # far plane in clipping space
|
||||
points_miss = points_miss @ (cam_mat4 @ np.linalg.inv(scan.projection_matrix)).T
|
||||
points_miss /= points_miss[:, 3][:, np.newaxis]
|
||||
points_miss = points_miss[:, :3]
|
||||
|
||||
uv_hits = scan.depth_buffer != 0
|
||||
uv_miss = ~uv_hits
|
||||
|
||||
if not no_filter_backhits:
|
||||
if not compute_normals:
|
||||
raise ValueError("not `no_filter_backhits` requires `compute_normals`")
|
||||
# inner product
|
||||
mask = np.einsum('ij,ij->i', scan.points - cam_pos[:3][None, :], scan.normals) < 0
|
||||
scan.points = scan.points [mask, :]
|
||||
scan.normals = scan.normals[mask, :]
|
||||
uv_hits[uv_hits] = mask
|
||||
|
||||
transforms = {}
|
||||
|
||||
# undo unit-sphere transform
|
||||
if no_unit_sphere:
|
||||
scan.points = scan.points * (1 / scale) - translation
|
||||
points_miss = points_miss * (1 / scale) - translation
|
||||
cam_pos[:3] = cam_pos[:3] * (1 / scale) - translation
|
||||
cam_mat4[:3, :] *= 1 / scale
|
||||
cam_mat4[:3, 3] -= translation
|
||||
|
||||
transforms["unit_sphere"] = T.scale_and_translate(scale=scale, translate=translation)
|
||||
transforms["model"] = np.eye(4)
|
||||
else:
|
||||
transforms["model"] = np.linalg.inv(T.scale_and_translate(scale=scale, translate=translation))
|
||||
transforms["unit_sphere"] = np.eye(4)
|
||||
|
||||
return SingleViewScan(
|
||||
normals_hit = scan.normals .astype(dtype),
|
||||
points_hit = scan.points .astype(dtype),
|
||||
points_miss = points_miss .astype(dtype),
|
||||
distances_miss = None,
|
||||
colors_hit = None,
|
||||
colors_miss = None,
|
||||
uv_hits = uv_hits .astype(bool),
|
||||
uv_miss = uv_miss .astype(bool),
|
||||
cam_pos = cam_pos[:3] .astype(dtype),
|
||||
cam_mat4 = cam_mat4 .astype(dtype),
|
||||
proj_mat4 = scan.projection_matrix .astype(dtype),
|
||||
transforms = {k:v.astype(dtype) for k, v in transforms.items()},
|
||||
)
|
||||
|
||||
def sample_sphere_view_scan_from_mesh(
|
||||
mesh : Trimesh,
|
||||
*,
|
||||
sphere_points : int = 4000, # resulting rays are n*(n-1)
|
||||
compute_normals : bool = False,
|
||||
no_filter_backhits : bool = False,
|
||||
no_unit_sphere : bool = False,
|
||||
no_permute : bool = False,
|
||||
dtype : type = np.float32,
|
||||
**kw,
|
||||
) -> SingleViewUVScan:
|
||||
translation, scale = compute_unit_sphere_transform(mesh, dtype=dtype)
|
||||
|
||||
# get unit-sphere points, then transform to model space
|
||||
two_sphere = generate_equidistant_sphere_rays(sphere_points, **kw).astype(dtype) # (n*(n-1), 2, 3)
|
||||
two_sphere = two_sphere / scale - translation # we transform after cache lookup
|
||||
|
||||
if mesh.ray.__class__.__module__.split(".")[-1] != "ray_pyembree":
|
||||
warnings.warn("Pyembree not found, the ray-tracing will be SLOW!")
|
||||
|
||||
(
|
||||
locations,
|
||||
index_ray,
|
||||
index_tri,
|
||||
) = mesh.ray.intersects_location(
|
||||
two_sphere[:, 0, :],
|
||||
two_sphere[:, 1, :] - two_sphere[:, 0, :], # direction, not target coordinate
|
||||
multiple_hits=False,
|
||||
)
|
||||
|
||||
|
||||
if compute_normals:
|
||||
location_normals = mesh.face_normals[index_tri]
|
||||
|
||||
batch = two_sphere.shape[:1]
|
||||
hits = np.zeros((*batch,), dtype=np.bool)
|
||||
miss = np.ones((*batch,), dtype=np.bool)
|
||||
cam_pos = two_sphere[:, 0, :]
|
||||
intersections = two_sphere[:, 1, :] # far-plane, effectively
|
||||
normals = np.zeros((*batch, 3), dtype=dtype)
|
||||
|
||||
index_ray_front = index_ray
|
||||
if not no_filter_backhits:
|
||||
if not compute_normals:
|
||||
raise ValueError("not `no_filter_backhits` requires `compute_normals`")
|
||||
mask = ((intersections[index_ray] - cam_pos[index_ray]) * location_normals).sum(axis=-1) <= 0
|
||||
index_ray_front = index_ray[mask]
|
||||
|
||||
|
||||
hits[index_ray_front] = True
|
||||
miss[index_ray] = False
|
||||
intersections[index_ray] = locations
|
||||
normals[index_ray] = location_normals
|
||||
|
||||
|
||||
if not no_permute:
|
||||
assert len(batch) == 1, batch
|
||||
permutation = np.random.permutation(*batch)
|
||||
hits = hits [permutation]
|
||||
miss = miss [permutation]
|
||||
intersections = intersections[permutation, :]
|
||||
normals = normals [permutation, :]
|
||||
cam_pos = cam_pos [permutation, :]
|
||||
|
||||
# apply unit sphere transform
|
||||
if not no_unit_sphere:
|
||||
intersections = (intersections + translation) * scale
|
||||
cam_pos = (cam_pos + translation) * scale
|
||||
|
||||
return SingleViewUVScan(
|
||||
hits = hits,
|
||||
miss = miss,
|
||||
points = intersections,
|
||||
normals = normals,
|
||||
colors = None, # colors
|
||||
distances = None,
|
||||
cam_pos = cam_pos,
|
||||
cam_mat4 = None,
|
||||
proj_mat4 = None,
|
||||
transforms = {},
|
||||
)
|
||||
|
||||
def distance_from_rays_to_point_cloud(
|
||||
ray_origins : np.ndarray, # (*A, 3)
|
||||
ray_dirs : np.ndarray, # (*A, 3)
|
||||
points : np.ndarray, # (*B, 3)
|
||||
dirs_normalized : bool = False,
|
||||
n_steps : int = 40,
|
||||
) -> np.ndarray: # (A)
|
||||
|
||||
# anything outside of this volume will never constribute to the result
|
||||
max_norm = max(
|
||||
np.linalg.norm(ray_origins, axis=-1).max(),
|
||||
np.linalg.norm(points, axis=-1).max(),
|
||||
) * 1.02
|
||||
|
||||
if not dirs_normalized:
|
||||
ray_dirs = ray_dirs / np.linalg.norm(ray_dirs, axis=-1)[..., None]
|
||||
|
||||
|
||||
# deal with single-view clouds
|
||||
if ray_origins.shape != ray_dirs.shape:
|
||||
ray_origins = np.broadcast_to(ray_origins, ray_dirs.shape)
|
||||
|
||||
n_points = np.product(points.shape[:-1])
|
||||
use_faiss = n_points > 160000*4
|
||||
if not use_faiss:
|
||||
index = BallTree(points)
|
||||
else:
|
||||
# http://ann-benchmarks.com/index.html
|
||||
assert np.issubdtype(points.dtype, np.float32)
|
||||
assert np.issubdtype(ray_origins.dtype, np.float32)
|
||||
assert np.issubdtype(ray_dirs.dtype, np.float32)
|
||||
index = faiss.index_factory(points.shape[-1], "NSG32,Flat") # https://github.com/facebookresearch/faiss/wiki/The-index-factory
|
||||
|
||||
index.nprobe = 5 # 10 # default is 1
|
||||
index.train(points)
|
||||
index.add(points)
|
||||
|
||||
if not use_faiss:
|
||||
min_d, min_n = index.query(ray_origins, k=1, return_distance=True)
|
||||
else:
|
||||
min_d, min_n = index.search(ray_origins, k=1)
|
||||
min_d = np.sqrt(min_d)
|
||||
acc_d = min_d.copy()
|
||||
|
||||
for step in range(1, n_steps+1):
|
||||
query_points = ray_origins + acc_d * ray_dirs
|
||||
if max_norm is not None:
|
||||
qmask = np.linalg.norm(query_points, axis=-1) < max_norm
|
||||
if not qmask.any(): break
|
||||
query_points = query_points[qmask]
|
||||
else:
|
||||
qmask = slice(None)
|
||||
if not use_faiss:
|
||||
current_d, current_n = index.query(query_points, k=1, return_distance=True)
|
||||
else:
|
||||
current_d, current_n = index.search(query_points, k=1)
|
||||
current_d = np.sqrt(current_d)
|
||||
if max_norm is not None:
|
||||
min_d[qmask] = np.minimum(current_d, min_d[qmask])
|
||||
new_min_mask = min_d[qmask] == current_d
|
||||
qmask2 = qmask.copy()
|
||||
qmask2[qmask2] = new_min_mask[..., 0]
|
||||
min_n[qmask2] = current_n[new_min_mask[..., 0]]
|
||||
acc_d[qmask] += current_d * 0.25
|
||||
else:
|
||||
np.minimum(current_d, min_d, out=min_d)
|
||||
new_min_mask = min_d == current_d
|
||||
min_n[new_min_mask] = current_n[new_min_mask]
|
||||
acc_d += current_d * 0.25
|
||||
|
||||
closest_points = points[min_n[:, 0], :] # k=1
|
||||
distances = np.linalg.norm(np.cross(closest_points - ray_origins, ray_dirs, axis=-1), axis=-1)
|
||||
return distances
|
||||
|
||||
# helpers
|
||||
|
||||
@compose(np.array) # make copy to avoid lru cache mutation
|
||||
@lru_cache(maxsize=1)
|
||||
def generate_equidistant_sphere_rays(n : int, **kw) -> np.ndarray: # output (n*n(-1)) rays, n may be off
|
||||
sphere_points = points.generate_equidistant_sphere_points(n=n, **kw)
|
||||
|
||||
indices = np.indices((len(sphere_points),))[0] # (N)
|
||||
# cartesian product
|
||||
cprod = np.transpose([np.tile(indices, len(indices)), np.repeat(indices, len(indices))]) # (N**2, 2)
|
||||
# filter repeated combinations
|
||||
permutations = cprod[cprod[:, 0] != cprod[:, 1], :] # (N*(N-1), 2)
|
||||
# lookup sphere points
|
||||
two_sphere = sphere_points[permutations, :] # (N*(N-1), 2, 3)
|
||||
|
||||
return two_sphere
|
||||
|
||||
def compute_unit_sphere_transform(mesh: Trimesh, *, dtype=type) -> tuple[np.ndarray, float]:
|
||||
"""
|
||||
returns translation and scale which mesh_to_sdf applies to meshes before computing their SDF cloud
|
||||
"""
|
||||
# the transformation applied by mesh_to_sdf.scale_to_unit_sphere(mesh)
|
||||
translation = -mesh.bounding_box.centroid
|
||||
scale = 1 / np.max(np.linalg.norm(mesh.vertices + translation, axis=1))
|
||||
if dtype is not None:
|
||||
translation = translation.astype(dtype)
|
||||
scale = scale .astype(dtype)
|
||||
return translation, scale
|
||||
6
ifield/data/common/types.py
Normal file
6
ifield/data/common/types.py
Normal file
@@ -0,0 +1,6 @@
|
||||
__doc__ = """
|
||||
Some helper types.
|
||||
"""
|
||||
|
||||
class MalformedMesh(Exception):
|
||||
pass
|
||||
28
ifield/data/config.py
Normal file
28
ifield/data/config.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from ..utils.helpers import make_relative
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
import os
|
||||
import warnings
|
||||
|
||||
|
||||
def data_path_get(dataset_name: str, no_warn: bool = False) -> Path:
|
||||
dataset_envvar = f"IFIELD_DATA_MODELS_{dataset_name.replace(*'-_').upper()}"
|
||||
if dataset_envvar in os.environ:
|
||||
data_path = Path(os.environ[dataset_envvar])
|
||||
elif "IFIELD_DATA_MODELS" in os.environ:
|
||||
data_path = Path(os.environ["IFIELD_DATA_MODELS"]) / dataset_name
|
||||
else:
|
||||
data_path = Path(__file__).resolve().parent.parent.parent / "data" / "models" / dataset_name
|
||||
if not data_path.is_dir() and not no_warn:
|
||||
warnings.warn(f"{make_relative(data_path, Path.cwd()).__str__()!r} is not a directory!")
|
||||
return data_path
|
||||
|
||||
def data_path_persist(dataset_name: Optional[str], path: os.PathLike) -> os.PathLike:
|
||||
"Persist the datapath, ensuring subprocesses also will use it. The path passes through."
|
||||
|
||||
if dataset_name is None:
|
||||
os.environ["IFIELD_DATA_MODELS"] = str(path)
|
||||
else:
|
||||
os.environ[f"IFIELD_DATA_MODELS_{dataset_name.replace(*'-_').upper()}"] = str(path)
|
||||
|
||||
return path
|
||||
56
ifield/data/coseg/__init__.py
Normal file
56
ifield/data/coseg/__init__.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from ..config import data_path_get, data_path_persist
|
||||
from collections import namedtuple
|
||||
import os
|
||||
|
||||
|
||||
# Data source:
|
||||
# http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/ssd.htm
|
||||
|
||||
__ALL__ = ["config", "Model", "MODELS"]
|
||||
|
||||
Archive = namedtuple("Archive", "url fname download_size_str")
|
||||
|
||||
@(lambda x: x()) # singleton
|
||||
class config:
|
||||
DATA_PATH = property(
|
||||
doc = """
|
||||
Path to the dataset. The following envvars override it:
|
||||
${IFIELD_DATA_MODELS}/coseg
|
||||
${IFIELD_DATA_MODELS_COSEG}
|
||||
""",
|
||||
fget = lambda self: data_path_get ("coseg"),
|
||||
fset = lambda self, path: data_path_persist("coseg", path),
|
||||
)
|
||||
|
||||
@property
|
||||
def IS_DOWNLOADED_DB(self) -> list[os.PathLike]:
|
||||
return [
|
||||
self.DATA_PATH / "downloaded.json",
|
||||
]
|
||||
|
||||
SHAPES: dict[str, Archive] = {
|
||||
"candelabra" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Candelabra/shapes.zip", "candelabra-shapes.zip", "3,3M"),
|
||||
"chair" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Chair/shapes.zip", "chair-shapes.zip", "3,2M"),
|
||||
"four-legged" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Four-legged/shapes.zip", "four-legged-shapes.zip", "2,9M"),
|
||||
"goblets" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Goblets/shapes.zip", "goblets-shapes.zip", "500K"),
|
||||
"guitars" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Guitars/shapes.zip", "guitars-shapes.zip", "1,9M"),
|
||||
"lampes" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Lampes/shapes.zip", "lampes-shapes.zip", "2,4M"),
|
||||
"vases" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Vases/shapes.zip", "vases-shapes.zip", "5,5M"),
|
||||
"irons" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Irons/shapes.zip", "irons-shapes.zip", "1,2M"),
|
||||
"tele-aliens" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Tele-aliens/shapes.zip", "tele-aliens-shapes.zip", "15M"),
|
||||
"large-vases" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Large-Vases/shapes.zip", "large-vases-shapes.zip", "6,2M"),
|
||||
"large-chairs": Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Large-Chairs/shapes.zip", "large-chairs-shapes.zip", "14M"),
|
||||
}
|
||||
GROUND_TRUTHS: dict[str, Archive] = {
|
||||
"candelabra" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Candelabra/gt.zip", "candelabra-gt.zip", "68K"),
|
||||
"chair" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Chair/gt.zip", "chair-gt.zip", "20K"),
|
||||
"four-legged" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Four-legged/gt.zip", "four-legged-gt.zip", "24K"),
|
||||
"goblets" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Goblets/gt.zip", "goblets-gt.zip", "4,0K"),
|
||||
"guitars" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Guitars/gt.zip", "guitars-gt.zip", "12K"),
|
||||
"lampes" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Lampes/gt.zip", "lampes-gt.zip", "60K"),
|
||||
"vases" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Vases/gt.zip", "vases-gt.zip", "40K"),
|
||||
"irons" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Irons/gt.zip", "irons-gt.zip", "8,0K"),
|
||||
"tele-aliens" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Tele-aliens/gt.zip", "tele-aliens-gt.zip", "72K"),
|
||||
"large-vases" : Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Large-Vases/gt.zip", "large-vases-gt.zip", "68K"),
|
||||
"large-chairs": Archive("http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/data/Large-Chairs/gt.zip", "large-chairs-gt.zip", "116K"),
|
||||
}
|
||||
135
ifield/data/coseg/download.py
Normal file
135
ifield/data/coseg/download.py
Normal file
@@ -0,0 +1,135 @@
|
||||
#!/usr/bin/env python3
|
||||
from . import config
|
||||
from ...utils.helpers import make_relative
|
||||
from ..common import download
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
import argparse
|
||||
import io
|
||||
import zipfile
|
||||
|
||||
|
||||
|
||||
def is_downloaded(*a, **kw):
|
||||
return download.is_downloaded(*a, dbfiles=config.IS_DOWNLOADED_DB, **kw)
|
||||
|
||||
def download_and_extract(target_dir: Path, url_dict: dict[str, str], *, force=False, silent=False) -> bool:
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ret = False
|
||||
for url, fname in url_dict.items():
|
||||
if not force:
|
||||
if is_downloaded(target_dir, url): continue
|
||||
if not download.check_url(url):
|
||||
print("ERROR:", url)
|
||||
continue
|
||||
ret = True
|
||||
|
||||
if force or not (target_dir / "archives" / fname).is_file():
|
||||
|
||||
data = download.download_data(url, silent=silent, label=fname)
|
||||
assert url.endswith(".zip")
|
||||
|
||||
print("writing...")
|
||||
|
||||
(target_dir / "archives").mkdir(parents=True, exist_ok=True)
|
||||
with (target_dir / "archives" / fname).open("wb") as f:
|
||||
f.write(data)
|
||||
del data
|
||||
|
||||
print(f"extracting {fname}...")
|
||||
|
||||
with zipfile.ZipFile(target_dir / "archives" / fname, 'r') as f:
|
||||
f.extractall(target_dir / Path(fname).stem.removesuffix("-shapes").removesuffix("-gt"))
|
||||
|
||||
is_downloaded(target_dir, url, add=True)
|
||||
|
||||
return ret
|
||||
|
||||
def make_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description=dedent("""
|
||||
Download The COSEG Shape Dataset.
|
||||
More info: http://irc.cs.sdu.edu.cn/~yunhai/public_html/ssl/ssd.htm
|
||||
|
||||
Example:
|
||||
|
||||
download-coseg --shapes chairs
|
||||
"""), formatter_class=argparse.RawTextHelpFormatter)
|
||||
|
||||
arg = parser.add_argument
|
||||
|
||||
arg("sets", nargs="*", default=[],
|
||||
help="Which set to download, defaults to none.")
|
||||
arg("--all", action="store_true",
|
||||
help="Download all sets")
|
||||
arg("--dir", default=str(config.DATA_PATH),
|
||||
help=f"The target directory. Default is {make_relative(config.DATA_PATH, Path.cwd()).__str__()!r}")
|
||||
|
||||
arg("--shapes", action="store_true",
|
||||
help="Download the 3d shapes for each chosen set")
|
||||
arg("--gts", action="store_true",
|
||||
help="Download the ground-truth segmentation data for each chosen set")
|
||||
|
||||
arg("--list", action="store_true",
|
||||
help="Lists all the sets")
|
||||
arg("--list-urls", action="store_true",
|
||||
help="Lists the urls to download")
|
||||
arg("--list-sizes", action="store_true",
|
||||
help="Lists the download size of each set")
|
||||
arg("--silent", action="store_true",
|
||||
help="")
|
||||
arg("--force", action="store_true",
|
||||
help="Download again even if already downloaded")
|
||||
|
||||
return parser
|
||||
|
||||
# entrypoint
|
||||
def cli(parser=make_parser()):
|
||||
args = parser.parse_args()
|
||||
|
||||
assert set(config.SHAPES.keys()) == set(config.GROUND_TRUTHS.keys())
|
||||
|
||||
set_names = sorted(set(args.sets))
|
||||
if args.all:
|
||||
assert not set_names, "--all is mutually exclusive from manually selected sets"
|
||||
set_names = sorted(config.SHAPES.keys())
|
||||
|
||||
if args.list:
|
||||
print(*config.SHAPES.keys(), sep="\n")
|
||||
exit()
|
||||
|
||||
if args.list_sizes:
|
||||
print(*(f"{set_name:<15}{config.SHAPES[set_name].download_size_str}" for set_name in (set_names or config.SHAPES.keys())), sep="\n")
|
||||
exit()
|
||||
|
||||
try:
|
||||
url_dict \
|
||||
= {config.SHAPES[set_name].url : config.SHAPES[set_name].fname for set_name in set_names if args.shapes} \
|
||||
| {config.GROUND_TRUTHS[set_name].url : config.GROUND_TRUTHS[set_name].fname for set_name in set_names if args.gts}
|
||||
except KeyError:
|
||||
print("Error: unrecognized object name:", *set(set_names).difference(config.SHAPES.keys()), sep="\n")
|
||||
exit(1)
|
||||
|
||||
if not url_dict:
|
||||
if set_names and not (args.shapes or args.gts):
|
||||
print("Error: Provide at least one of --shapes of --gts")
|
||||
else:
|
||||
print("Error: No object set was selected for download!")
|
||||
exit(1)
|
||||
|
||||
if args.list_urls:
|
||||
print(*url_dict.keys(), sep="\n")
|
||||
exit()
|
||||
|
||||
print("Download start")
|
||||
any_downloaded = download_and_extract(
|
||||
target_dir = Path(args.dir),
|
||||
url_dict = url_dict,
|
||||
force = args.force,
|
||||
silent = args.silent,
|
||||
)
|
||||
if not any_downloaded:
|
||||
print("Everything has already been downloaded, skipping.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
137
ifield/data/coseg/preprocess.py
Normal file
137
ifield/data/coseg/preprocess.py
Normal file
@@ -0,0 +1,137 @@
|
||||
#!/usr/bin/env python3
|
||||
import os; os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
|
||||
from . import config, read
|
||||
from ...utils.helpers import make_relative
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
import argparse
|
||||
|
||||
|
||||
def make_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description=dedent("""
|
||||
Preprocess the COSEG dataset. Depends on `download-coseg --shapes ...` having been run.
|
||||
"""), formatter_class=argparse.RawTextHelpFormatter)
|
||||
|
||||
arg = parser.add_argument # brevity
|
||||
|
||||
arg("items", nargs="*", default=[],
|
||||
help="Which object-set[/model-id] to process, defaults to all downloaded. Format: OBJECT-SET[/MODEL-ID]")
|
||||
arg("--dir", default=str(config.DATA_PATH),
|
||||
help=f"The target directory. Default is {make_relative(config.DATA_PATH, Path.cwd()).__str__()!r}")
|
||||
arg("--force", action="store_true",
|
||||
help="Overwrite existing files")
|
||||
arg("--list-models", action="store_true",
|
||||
help="List the downloaded models available for preprocessing")
|
||||
arg("--list-object-sets", action="store_true",
|
||||
help="List the downloaded object-sets available for preprocessing")
|
||||
arg("--list-pages", type=int, default=None,
|
||||
help="List the downloaded models available for preprocessing, paginated into N pages.")
|
||||
arg("--page", nargs=2, type=int, default=[0, 1],
|
||||
help="Subset of parts to compute. Use to parallelize. (page, total), page is 0 indexed")
|
||||
|
||||
arg2 = parser.add_argument_group("preprocessing targets").add_argument # brevity
|
||||
arg2("--precompute-mesh-sv-scan-clouds", action="store_true",
|
||||
help="Compute single-view hit+miss point clouds from 100 synthetic scans.")
|
||||
arg2("--precompute-mesh-sv-scan-uvs", action="store_true",
|
||||
help="Compute single-view hit+miss UV clouds from 100 synthetic scans.")
|
||||
arg2("--precompute-mesh-sphere-scan", action="store_true",
|
||||
help="Compute a sphere-view hit+miss cloud cast from n to n unit sphere points.")
|
||||
|
||||
arg3 = parser.add_argument_group("modifiers").add_argument # brevity
|
||||
arg3("--n-sphere-points", type=int, default=4000,
|
||||
help="The number of unit-sphere points to sample rays from. Final result: n*(n-1).")
|
||||
arg3("--compute-miss-distances", action="store_true",
|
||||
help="Compute the distance to the nearest hit for each miss in the hit+miss clouds.")
|
||||
arg3("--fill-missing-uv-points", action="store_true",
|
||||
help="TODO")
|
||||
arg3("--no-filter-backhits", action="store_true",
|
||||
help="Do not filter scan hits on backside of mesh faces.")
|
||||
arg3("--no-unit-sphere", action="store_true",
|
||||
help="Do not center the objects to the unit sphere.")
|
||||
arg3("--convert-ok", action="store_true",
|
||||
help="Allow reusing point clouds for uv clouds and vice versa. (does not account for other hparams)")
|
||||
arg3("--debug", action="store_true",
|
||||
help="Abort on failiure.")
|
||||
|
||||
return parser
|
||||
|
||||
# entrypoint
|
||||
def cli(parser=make_parser()):
|
||||
args = parser.parse_args()
|
||||
if not any(getattr(args, k) for k in dir(args) if k.startswith("precompute_")) and not (args.list_models or args.list_object_sets or args.list_pages):
|
||||
parser.error("no preprocessing target selected") # exits
|
||||
|
||||
config.DATA_PATH = Path(args.dir)
|
||||
|
||||
object_sets = [i for i in args.items if "/" not in i]
|
||||
models = [i.split("/") for i in args.items if "/" in i]
|
||||
|
||||
# convert/expand synsets to models
|
||||
# they are mutually exclusive
|
||||
if object_sets: assert not models
|
||||
if models: assert not object_sets
|
||||
if not models:
|
||||
models = read.list_model_ids(tuple(object_sets) or None)
|
||||
|
||||
if args.list_models:
|
||||
try:
|
||||
print(*(f"{object_set_id}/{model_id}" for object_set_id, model_id in models), sep="\n")
|
||||
except BrokenPipeError:
|
||||
pass
|
||||
parser.exit()
|
||||
|
||||
if args.list_object_sets:
|
||||
try:
|
||||
print(*sorted(set(object_set_id for object_set_id, model_id in models)), sep="\n")
|
||||
except BrokenPipeError:
|
||||
pass
|
||||
parser.exit()
|
||||
|
||||
if args.list_pages is not None:
|
||||
try:
|
||||
print(*(
|
||||
f"--page {i} {args.list_pages} {object_set_id}/{model_id}"
|
||||
for object_set_id, model_id in models
|
||||
for i in range(args.list_pages)
|
||||
), sep="\n")
|
||||
except BrokenPipeError:
|
||||
pass
|
||||
parser.exit()
|
||||
|
||||
if args.precompute_mesh_sv_scan_clouds:
|
||||
read.precompute_mesh_scan_point_clouds(
|
||||
models,
|
||||
compute_miss_distances = args.compute_miss_distances,
|
||||
no_filter_backhits = args.no_filter_backhits,
|
||||
no_unit_sphere = args.no_unit_sphere,
|
||||
convert_ok = args.convert_ok,
|
||||
page = args.page,
|
||||
force = args.force,
|
||||
debug = args.debug,
|
||||
)
|
||||
if args.precompute_mesh_sv_scan_uvs:
|
||||
read.precompute_mesh_scan_uvs(
|
||||
models,
|
||||
compute_miss_distances = args.compute_miss_distances,
|
||||
fill_missing_points = args.fill_missing_uv_points,
|
||||
no_filter_backhits = args.no_filter_backhits,
|
||||
no_unit_sphere = args.no_unit_sphere,
|
||||
convert_ok = args.convert_ok,
|
||||
page = args.page,
|
||||
force = args.force,
|
||||
debug = args.debug,
|
||||
)
|
||||
if args.precompute_mesh_sphere_scan:
|
||||
read.precompute_mesh_sphere_scan(
|
||||
models,
|
||||
sphere_points = args.n_sphere_points,
|
||||
compute_miss_distances = args.compute_miss_distances,
|
||||
no_filter_backhits = args.no_filter_backhits,
|
||||
no_unit_sphere = args.no_unit_sphere,
|
||||
page = args.page,
|
||||
force = args.force,
|
||||
debug = args.debug,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
290
ifield/data/coseg/read.py
Normal file
290
ifield/data/coseg/read.py
Normal file
@@ -0,0 +1,290 @@
|
||||
from . import config
|
||||
from ..common import points
|
||||
from ..common import processing
|
||||
from ..common.scan import SingleViewScan, SingleViewUVScan
|
||||
from ..common.types import MalformedMesh
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Iterable
|
||||
import numpy as np
|
||||
import trimesh
|
||||
import trimesh.transformations as T
|
||||
|
||||
__doc__ = """
|
||||
Here are functions for reading and preprocessing coseg benchmark data
|
||||
|
||||
There are essentially a few sets per object:
|
||||
"img" - meaning the RGBD images (none found in coseg)
|
||||
"mesh_scans" - meaning synthetic scans of a mesh
|
||||
"""
|
||||
|
||||
MESH_TRANSFORM_SKYWARD = T.rotation_matrix(np.pi/2, (1, 0, 0)) # rotate to be upright in pyrender
|
||||
MESH_POSE_CORRECTIONS = { # to gain a shared canonical orientation
|
||||
("four-legged", 381): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 382): T.rotation_matrix( 1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 383): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 384): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 385): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 386): T.rotation_matrix( 1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 387): T.rotation_matrix(-0.2*np.pi/2, (0, 1, 0))@T.rotation_matrix(1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 388): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 389): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 390): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 391): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 392): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 393): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 394): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 395): T.rotation_matrix(-0.2*np.pi/2, (0, 1, 0))@T.rotation_matrix(1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 396): T.rotation_matrix( 1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 397): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 398): T.rotation_matrix( -1*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 399): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
|
||||
("four-legged", 400): T.rotation_matrix( 0*np.pi/2, (0, 0, 1)),
|
||||
}
|
||||
|
||||
|
||||
ModelUid = tuple[str, int]
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def list_object_sets() -> list[str]:
|
||||
return sorted(
|
||||
object_set.name
|
||||
for object_set in config.DATA_PATH.iterdir()
|
||||
if (object_set / "shapes").is_dir() and object_set.name != "archive"
|
||||
)
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def list_model_ids(object_sets: Optional[tuple[str]] = None) -> list[ModelUid]:
|
||||
return sorted(
|
||||
(object_set.name, int(model.stem))
|
||||
for object_set in config.DATA_PATH.iterdir()
|
||||
if (object_set / "shapes").is_dir() and object_set.name != "archive" and (object_sets is None or object_set.name in object_sets)
|
||||
for model in (object_set / "shapes").iterdir()
|
||||
if model.is_file() and model.suffix == ".off"
|
||||
)
|
||||
|
||||
def list_model_id_strings(object_sets: Optional[tuple[str]] = None) -> list[str]:
|
||||
return [model_uid_to_string(object_set_id, model_id) for object_set_id, model_id in list_model_ids(object_sets)]
|
||||
|
||||
def model_uid_to_string(object_set_id: str, model_id: int) -> str:
|
||||
return f"{object_set_id}-{model_id}"
|
||||
|
||||
def model_id_string_to_uid(model_string_uid: str) -> ModelUid:
|
||||
object_set, split, model = model_string_uid.rpartition("-")
|
||||
assert split == "-"
|
||||
return (object_set, int(model))
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def list_mesh_scan_sphere_coords(n_poses: int = 50) -> list[tuple[float, float]]: # (theta, phi)
|
||||
return points.generate_equidistant_sphere_points(n_poses, compute_sphere_coordinates=True)
|
||||
|
||||
def mesh_scan_identifier(*, phi: float, theta: float) -> str:
|
||||
return (
|
||||
f"{'np'[theta>=0]}{abs(theta):.2f}"
|
||||
f"{'np'[phi >=0]}{abs(phi) :.2f}"
|
||||
).replace(".", "d")
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def list_mesh_scan_identifiers(n_poses: int = 50) -> list[str]:
|
||||
out = [
|
||||
mesh_scan_identifier(phi=phi, theta=theta)
|
||||
for theta, phi in list_mesh_scan_sphere_coords(n_poses)
|
||||
]
|
||||
assert len(out) == len(set(out))
|
||||
return out
|
||||
|
||||
# ===
|
||||
|
||||
def read_mesh(object_set_id: str, model_id: int) -> trimesh.Trimesh:
|
||||
path = config.DATA_PATH / object_set_id / "shapes" / f"{model_id}.off"
|
||||
if not path.is_file():
|
||||
raise FileNotFoundError(f"{path = }")
|
||||
try:
|
||||
mesh = trimesh.load(path, force="mesh")
|
||||
except Exception as e:
|
||||
raise MalformedMesh(f"Trimesh raised: {e.__class__.__name__}: {e}") from e
|
||||
|
||||
pose = MESH_POSE_CORRECTIONS.get((object_set_id, int(model_id)))
|
||||
mesh.apply_transform(pose @ MESH_TRANSFORM_SKYWARD if pose is not None else MESH_TRANSFORM_SKYWARD)
|
||||
return mesh
|
||||
|
||||
# === single-view scan clouds
|
||||
|
||||
def compute_mesh_scan_point_cloud(
|
||||
object_set_id : str,
|
||||
model_id : int,
|
||||
phi : float,
|
||||
theta : float,
|
||||
*,
|
||||
compute_miss_distances : bool = False,
|
||||
fill_missing_points : bool = False,
|
||||
compute_normals : bool = True,
|
||||
convert_ok : bool = False,
|
||||
**kw,
|
||||
) -> SingleViewScan:
|
||||
|
||||
if convert_ok:
|
||||
try:
|
||||
return read_mesh_scan_uv(object_set_id, model_id, phi=phi, theta=theta).to_scan()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
mesh = read_mesh(object_set_id, model_id)
|
||||
scan = SingleViewScan.from_mesh_single_view(mesh,
|
||||
phi = phi,
|
||||
theta = theta,
|
||||
compute_normals = compute_normals,
|
||||
**kw,
|
||||
)
|
||||
if compute_miss_distances:
|
||||
scan.compute_miss_distances()
|
||||
if fill_missing_points:
|
||||
scan.fill_missing_points()
|
||||
|
||||
return scan
|
||||
|
||||
def precompute_mesh_scan_point_clouds(models: Iterable[ModelUid], *, n_poses: int = 50, page: tuple[int, int] = (0, 1), force = False, debug = False, **kw):
|
||||
"precomputes all single-view scan clouds and stores them as HDF5 datasets"
|
||||
cam_poses = list_mesh_scan_sphere_coords(n_poses=n_poses)
|
||||
pose_identifiers = list_mesh_scan_identifiers (n_poses=n_poses)
|
||||
assert len(cam_poses) == len(pose_identifiers)
|
||||
paths = list_mesh_scan_point_cloud_h5_fnames(models, pose_identifiers, n_poses=n_poses)
|
||||
mlen_syn = max(len(object_set_id) for object_set_id, model_id in models)
|
||||
mlen_mod = max(len(str(model_id)) for object_set_id, model_id in models)
|
||||
pretty_identifiers = [
|
||||
f"{object_set_id.ljust(mlen_syn)} @ {str(model_id).ljust(mlen_mod)} @ {i:>5} @ ({itentifier}: {theta:.2f}, {phi:.2f})"
|
||||
for object_set_id, model_id in models
|
||||
for i, (itentifier, (theta, phi)) in enumerate(zip(pose_identifiers, cam_poses))
|
||||
]
|
||||
mesh_cache = []
|
||||
def computer(pretty_identifier: str) -> SingleViewScan:
|
||||
object_set_id, model_id, index, _ = map(str.strip, pretty_identifier.split("@"))
|
||||
theta, phi = cam_poses[int(index)]
|
||||
return compute_mesh_scan_point_cloud(object_set_id, int(model_id), phi=phi, theta=theta, _mesh_cache=mesh_cache, **kw)
|
||||
return processing.precompute_data(computer, pretty_identifiers, paths, page=page, force=force, debug=debug)
|
||||
|
||||
def read_mesh_scan_point_cloud(object_set_id: str, model_id: int, *, identifier: str = None, phi: float = None, theta: float = None) -> SingleViewScan:
|
||||
if identifier is None:
|
||||
if phi is None or theta is None:
|
||||
raise ValueError("Provide either phi+theta or an identifier!")
|
||||
identifier = mesh_scan_identifier(phi=phi, theta=theta)
|
||||
file = config.DATA_PATH / object_set_id / "uv_scan_clouds" / f"{model_id}_normalized_{identifier}.h5"
|
||||
return SingleViewScan.from_h5_file(file)
|
||||
|
||||
def list_mesh_scan_point_cloud_h5_fnames(models: Iterable[ModelUid], identifiers: Optional[Iterable[str]] = None, **kw):
|
||||
if identifiers is None:
|
||||
identifiers = list_mesh_scan_identifiers(**kw)
|
||||
return [
|
||||
config.DATA_PATH / object_set_id / "uv_scan_clouds" / f"{model_id}_normalized_{identifier}.h5"
|
||||
for object_set_id, model_id in models
|
||||
for identifier in identifiers
|
||||
]
|
||||
|
||||
|
||||
# === single-view UV scan clouds
|
||||
|
||||
def compute_mesh_scan_uv(
|
||||
object_set_id : str,
|
||||
model_id : int,
|
||||
phi : float,
|
||||
theta : float,
|
||||
*,
|
||||
compute_miss_distances : bool = False,
|
||||
fill_missing_points : bool = False,
|
||||
compute_normals : bool = True,
|
||||
convert_ok : bool = False,
|
||||
**kw,
|
||||
) -> SingleViewUVScan:
|
||||
|
||||
if convert_ok:
|
||||
try:
|
||||
return read_mesh_scan_point_cloud(object_set_id, model_id, phi=phi, theta=theta).to_uv_scan()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
mesh = read_mesh(object_set_id, model_id)
|
||||
scan = SingleViewUVScan.from_mesh_single_view(mesh,
|
||||
phi = phi,
|
||||
theta = theta,
|
||||
compute_normals = compute_normals,
|
||||
**kw,
|
||||
)
|
||||
if compute_miss_distances:
|
||||
scan.compute_miss_distances()
|
||||
if fill_missing_points:
|
||||
scan.fill_missing_points()
|
||||
|
||||
return scan
|
||||
|
||||
def precompute_mesh_scan_uvs(models: Iterable[ModelUid], *, n_poses: int = 50, page: tuple[int, int] = (0, 1), force = False, debug = False, **kw):
|
||||
"precomputes all single-view scan clouds and stores them as HDF5 datasets"
|
||||
cam_poses = list_mesh_scan_sphere_coords(n_poses=n_poses)
|
||||
pose_identifiers = list_mesh_scan_identifiers (n_poses=n_poses)
|
||||
assert len(cam_poses) == len(pose_identifiers)
|
||||
paths = list_mesh_scan_uv_h5_fnames(models, pose_identifiers, n_poses=n_poses)
|
||||
mlen_syn = max(len(object_set_id) for object_set_id, model_id in models)
|
||||
mlen_mod = max(len(str(model_id)) for object_set_id, model_id in models)
|
||||
pretty_identifiers = [
|
||||
f"{object_set_id.ljust(mlen_syn)} @ {str(model_id).ljust(mlen_mod)} @ {i:>5} @ ({itentifier}: {theta:.2f}, {phi:.2f})"
|
||||
for object_set_id, model_id in models
|
||||
for i, (itentifier, (theta, phi)) in enumerate(zip(pose_identifiers, cam_poses))
|
||||
]
|
||||
mesh_cache = []
|
||||
def computer(pretty_identifier: str) -> SingleViewUVScan:
|
||||
object_set_id, model_id, index, _ = map(str.strip, pretty_identifier.split("@"))
|
||||
theta, phi = cam_poses[int(index)]
|
||||
return compute_mesh_scan_uv(object_set_id, int(model_id), phi=phi, theta=theta, _mesh_cache=mesh_cache, **kw)
|
||||
return processing.precompute_data(computer, pretty_identifiers, paths, page=page, force=force, debug=debug)
|
||||
|
||||
def read_mesh_scan_uv(object_set_id: str, model_id: int, *, identifier: str = None, phi: float = None, theta: float = None) -> SingleViewUVScan:
|
||||
if identifier is None:
|
||||
if phi is None or theta is None:
|
||||
raise ValueError("Provide either phi+theta or an identifier!")
|
||||
identifier = mesh_scan_identifier(phi=phi, theta=theta)
|
||||
file = config.DATA_PATH / object_set_id / "uv_scan_clouds" / f"{model_id}_normalized_{identifier}.h5"
|
||||
|
||||
return SingleViewUVScan.from_h5_file(file)
|
||||
|
||||
def list_mesh_scan_uv_h5_fnames(models: Iterable[ModelUid], identifiers: Optional[Iterable[str]] = None, **kw):
|
||||
if identifiers is None:
|
||||
identifiers = list_mesh_scan_identifiers(**kw)
|
||||
return [
|
||||
config.DATA_PATH / object_set_id / "uv_scan_clouds" / f"{model_id}_normalized_{identifier}.h5"
|
||||
for object_set_id, model_id in models
|
||||
for identifier in identifiers
|
||||
]
|
||||
|
||||
|
||||
# === sphere-view (UV) scan clouds
|
||||
|
||||
def compute_mesh_sphere_scan(
|
||||
object_set_id : str,
|
||||
model_id : int,
|
||||
*,
|
||||
compute_normals : bool = True,
|
||||
**kw,
|
||||
) -> SingleViewUVScan:
|
||||
mesh = read_mesh(object_set_id, model_id)
|
||||
scan = SingleViewUVScan.from_mesh_sphere_view(mesh,
|
||||
compute_normals = compute_normals,
|
||||
**kw,
|
||||
)
|
||||
return scan
|
||||
|
||||
def precompute_mesh_sphere_scan(models: Iterable[ModelUid], *, page: tuple[int, int] = (0, 1), force: bool = False, debug: bool = False, n_points: int = 4000, **kw):
|
||||
"precomputes all sphere scan clouds and stores them as HDF5 datasets"
|
||||
paths = list_mesh_sphere_scan_h5_fnames(models)
|
||||
identifiers = [model_uid_to_string(*i) for i in models]
|
||||
def computer(identifier: str) -> SingleViewScan:
|
||||
object_set_id, model_id = model_id_string_to_uid(identifier)
|
||||
return compute_mesh_sphere_scan(object_set_id, model_id, **kw)
|
||||
return processing.precompute_data(computer, identifiers, paths, page=page, force=force, debug=debug)
|
||||
|
||||
def read_mesh_mesh_sphere_scan(object_set_id: str, model_id: int) -> SingleViewUVScan:
|
||||
file = config.DATA_PATH / object_set_id / "sphere_scan_clouds" / f"{model_id}_normalized.h5"
|
||||
return SingleViewUVScan.from_h5_file(file)
|
||||
|
||||
def list_mesh_sphere_scan_h5_fnames(models: Iterable[ModelUid]) -> list[str]:
|
||||
return [
|
||||
config.DATA_PATH / object_set_id / "sphere_scan_clouds" / f"{model_id}_normalized.h5"
|
||||
for object_set_id, model_id in models
|
||||
]
|
||||
76
ifield/data/stanford/__init__.py
Normal file
76
ifield/data/stanford/__init__.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from ..config import data_path_get, data_path_persist
|
||||
from collections import namedtuple
|
||||
import os
|
||||
|
||||
|
||||
# Data source:
|
||||
# http://graphics.stanford.edu/data/3Dscanrep/
|
||||
|
||||
__ALL__ = ["config", "Model", "MODELS"]
|
||||
|
||||
@(lambda x: x()) # singleton
|
||||
class config:
|
||||
DATA_PATH = property(
|
||||
doc = """
|
||||
Path to the dataset. The following envvars override it:
|
||||
${IFIELD_DATA_MODELS}/stanford
|
||||
${IFIELD_DATA_MODELS_STANFORD}
|
||||
""",
|
||||
fget = lambda self: data_path_get ("stanford"),
|
||||
fset = lambda self, path: data_path_persist("stanford", path),
|
||||
)
|
||||
|
||||
@property
|
||||
def IS_DOWNLOADED_DB(self) -> list[os.PathLike]:
|
||||
return [
|
||||
self.DATA_PATH / "downloaded.json",
|
||||
]
|
||||
|
||||
Model = namedtuple("Model", "url mesh_fname download_size_str")
|
||||
MODELS: dict[str, Model] = {
|
||||
"bunny": Model(
|
||||
"http://graphics.stanford.edu/pub/3Dscanrep/bunny.tar.gz",
|
||||
"bunny/reconstruction/bun_zipper.ply",
|
||||
"4.89M",
|
||||
),
|
||||
"drill_bit": Model(
|
||||
"http://graphics.stanford.edu/pub/3Dscanrep/drill.tar.gz",
|
||||
"drill/reconstruction/drill_shaft_vrip.ply",
|
||||
"555k",
|
||||
),
|
||||
"happy_buddha": Model(
|
||||
# religious symbol
|
||||
"http://graphics.stanford.edu/pub/3Dscanrep/happy/happy_recon.tar.gz",
|
||||
"happy_recon/happy_vrip.ply",
|
||||
"14.5M",
|
||||
),
|
||||
"dragon": Model(
|
||||
# symbol of Chinese culture
|
||||
"http://graphics.stanford.edu/pub/3Dscanrep/dragon/dragon_recon.tar.gz",
|
||||
"dragon_recon/dragon_vrip.ply",
|
||||
"11.2M",
|
||||
),
|
||||
"armadillo": Model(
|
||||
"http://graphics.stanford.edu/pub/3Dscanrep/armadillo/Armadillo.ply.gz",
|
||||
"armadillo.ply.gz",
|
||||
"3.87M",
|
||||
),
|
||||
"lucy": Model(
|
||||
# Christian angel
|
||||
"http://graphics.stanford.edu/data/3Dscanrep/lucy.tar.gz",
|
||||
"lucy.ply",
|
||||
"322M",
|
||||
),
|
||||
"asian_dragon": Model(
|
||||
# symbol of Chinese culture
|
||||
"http://graphics.stanford.edu/data/3Dscanrep/xyzrgb/xyzrgb_dragon.ply.gz",
|
||||
"xyzrgb_dragon.ply.gz",
|
||||
"70.5M",
|
||||
),
|
||||
"thai_statue": Model(
|
||||
# Hindu religious significance
|
||||
"http://graphics.stanford.edu/data/3Dscanrep/xyzrgb/xyzrgb_statuette.ply.gz",
|
||||
"xyzrgb_statuette.ply.gz",
|
||||
"106M",
|
||||
),
|
||||
}
|
||||
129
ifield/data/stanford/download.py
Normal file
129
ifield/data/stanford/download.py
Normal file
@@ -0,0 +1,129 @@
|
||||
#!/usr/bin/env python3
|
||||
from . import config
|
||||
from ...utils.helpers import make_relative
|
||||
from ..common import download
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
from typing import Iterable
|
||||
import argparse
|
||||
import io
|
||||
import tarfile
|
||||
|
||||
|
||||
def is_downloaded(*a, **kw):
|
||||
return download.is_downloaded(*a, dbfiles=config.IS_DOWNLOADED_DB, **kw)
|
||||
|
||||
def download_and_extract(target_dir: Path, url_list: Iterable[str], *, force=False, silent=False) -> bool:
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ret = False
|
||||
for url in url_list:
|
||||
if not force:
|
||||
if is_downloaded(target_dir, url): continue
|
||||
if not download.check_url(url):
|
||||
print("ERROR:", url)
|
||||
continue
|
||||
ret = True
|
||||
|
||||
data = download.download_data(url, silent=silent, label=str(Path(url).name))
|
||||
|
||||
print("extracting...")
|
||||
if url.endswith(".ply.gz"):
|
||||
fname = target_dir / "meshes" / url.split("/")[-1].lower()
|
||||
fname.parent.mkdir(parents=True, exist_ok=True)
|
||||
with fname.open("wb") as f:
|
||||
f.write(data)
|
||||
elif url.endswith(".tar.gz"):
|
||||
with tarfile.open(fileobj=io.BytesIO(data)) as tar:
|
||||
for member in tar.getmembers():
|
||||
if not member.isfile(): continue
|
||||
if member.name.startswith("/"): continue
|
||||
if member.name.startswith("."): continue
|
||||
if Path(member.name).name.startswith("."): continue
|
||||
tar.extract(member, target_dir / "meshes")
|
||||
del tar
|
||||
else:
|
||||
raise NotImplementedError(f"Extraction for {str(Path(url).name)} unknown")
|
||||
|
||||
is_downloaded(target_dir, url, add=True)
|
||||
del data
|
||||
|
||||
return ret
|
||||
|
||||
def make_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description=dedent("""
|
||||
Download The Stanford 3D Scanning Repository models.
|
||||
More info: http://graphics.stanford.edu/data/3Dscanrep/
|
||||
|
||||
Example:
|
||||
|
||||
download-stanford bunny
|
||||
"""), formatter_class=argparse.RawTextHelpFormatter)
|
||||
|
||||
arg = parser.add_argument
|
||||
|
||||
arg("objects", nargs="*", default=[],
|
||||
help="Which objects to download, defaults to none.")
|
||||
arg("--all", action="store_true",
|
||||
help="Download all objects")
|
||||
arg("--dir", default=str(config.DATA_PATH),
|
||||
help=f"The target directory. Default is {make_relative(config.DATA_PATH, Path.cwd()).__str__()!r}")
|
||||
|
||||
arg("--list", action="store_true",
|
||||
help="Lists all the objects")
|
||||
arg("--list-urls", action="store_true",
|
||||
help="Lists the urls to download")
|
||||
arg("--list-sizes", action="store_true",
|
||||
help="Lists the download size of each model")
|
||||
arg("--silent", action="store_true",
|
||||
help="")
|
||||
arg("--force", action="store_true",
|
||||
help="Download again even if already downloaded")
|
||||
|
||||
return parser
|
||||
|
||||
# entrypoint
|
||||
def cli(parser=make_parser()):
|
||||
args = parser.parse_args()
|
||||
|
||||
obj_names = sorted(set(args.objects))
|
||||
if args.all:
|
||||
assert not obj_names
|
||||
obj_names = sorted(config.MODELS.keys())
|
||||
if not obj_names and args.list_urls: config.MODELS.keys()
|
||||
|
||||
if args.list:
|
||||
print(*config.MODELS.keys(), sep="\n")
|
||||
exit()
|
||||
|
||||
if args.list_sizes:
|
||||
print(*(f"{obj_name:<15}{config.MODELS[obj_name].download_size_str}" for obj_name in (obj_names or config.MODELS.keys())), sep="\n")
|
||||
exit()
|
||||
|
||||
try:
|
||||
url_list = [config.MODELS[obj_name].url for obj_name in obj_names]
|
||||
except KeyError:
|
||||
print("Error: unrecognized object name:", *set(obj_names).difference(config.MODELS.keys()), sep="\n")
|
||||
exit(1)
|
||||
|
||||
if not url_list:
|
||||
print("Error: No object set was selected for download!")
|
||||
exit(1)
|
||||
|
||||
if args.list_urls:
|
||||
print(*url_list, sep="\n")
|
||||
exit()
|
||||
|
||||
|
||||
print("Download start")
|
||||
any_downloaded = download_and_extract(
|
||||
target_dir = Path(args.dir),
|
||||
url_list = url_list,
|
||||
force = args.force,
|
||||
silent = args.silent,
|
||||
)
|
||||
if not any_downloaded:
|
||||
print("Everything has already been downloaded, skipping.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
118
ifield/data/stanford/preprocess.py
Normal file
118
ifield/data/stanford/preprocess.py
Normal file
@@ -0,0 +1,118 @@
|
||||
#!/usr/bin/env python3
|
||||
import os; os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
|
||||
from . import config, read
|
||||
from ...utils.helpers import make_relative
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
import argparse
|
||||
|
||||
|
||||
|
||||
def make_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description=dedent("""
|
||||
Preprocess the Stanford models. Depends on `download-stanford` having been run.
|
||||
"""), formatter_class=argparse.RawTextHelpFormatter)
|
||||
|
||||
arg = parser.add_argument # brevity
|
||||
|
||||
arg("objects", nargs="*", default=[],
|
||||
help="Which objects to process, defaults to all downloaded")
|
||||
arg("--dir", default=str(config.DATA_PATH),
|
||||
help=f"The target directory. Default is {make_relative(config.DATA_PATH, Path.cwd()).__str__()!r}")
|
||||
arg("--force", action="store_true",
|
||||
help="Overwrite existing files")
|
||||
arg("--list", action="store_true",
|
||||
help="List the downloaded models available for preprocessing")
|
||||
arg("--list-pages", type=int, default=None,
|
||||
help="List the downloaded models available for preprocessing, paginated into N pages.")
|
||||
arg("--page", nargs=2, type=int, default=[0, 1],
|
||||
help="Subset of parts to compute. Use to parallelize. (page, total), page is 0 indexed")
|
||||
|
||||
arg2 = parser.add_argument_group("preprocessing targets").add_argument # brevity
|
||||
arg2("--precompute-mesh-sv-scan-clouds", action="store_true",
|
||||
help="Compute single-view hit+miss point clouds from 100 synthetic scans.")
|
||||
arg2("--precompute-mesh-sv-scan-uvs", action="store_true",
|
||||
help="Compute single-view hit+miss UV clouds from 100 synthetic scans.")
|
||||
arg2("--precompute-mesh-sphere-scan", action="store_true",
|
||||
help="Compute a sphere-view hit+miss cloud cast from n to n unit sphere points.")
|
||||
|
||||
arg3 = parser.add_argument_group("ray-scan modifiers").add_argument # brevity
|
||||
arg3("--n-sphere-points", type=int, default=4000,
|
||||
help="The number of unit-sphere points to sample rays from. Final result: n*(n-1).")
|
||||
arg3("--compute-miss-distances", action="store_true",
|
||||
help="Compute the distance to the nearest hit for each miss in the hit+miss clouds.")
|
||||
arg3("--fill-missing-uv-points", action="store_true",
|
||||
help="TODO")
|
||||
arg3("--no-filter-backhits", action="store_true",
|
||||
help="Do not filter scan hits on backside of mesh faces.")
|
||||
arg3("--no-unit-sphere", action="store_true",
|
||||
help="Do not center the objects to the unit sphere.")
|
||||
arg3("--convert-ok", action="store_true",
|
||||
help="Allow reusing point clouds for uv clouds and vice versa. (does not account for other hparams)")
|
||||
arg3("--debug", action="store_true",
|
||||
help="Abort on failiure.")
|
||||
|
||||
arg5 = parser.add_argument_group("Shared modifiers").add_argument # brevity
|
||||
arg5("--scan-resolution", type=int, default=400,
|
||||
help="The resolution of the depth map rendered to sample points. Becomes x*x")
|
||||
|
||||
return parser
|
||||
|
||||
# entrypoint
|
||||
def cli(parser: argparse.ArgumentParser = make_parser()):
|
||||
args = parser.parse_args()
|
||||
if not any(getattr(args, k) for k in dir(args) if k.startswith("precompute_")) and not (args.list or args.list_pages):
|
||||
parser.error("no preprocessing target selected") # exits
|
||||
|
||||
config.DATA_PATH = Path(args.dir)
|
||||
obj_names = args.objects or read.list_object_names()
|
||||
|
||||
if args.list:
|
||||
print(*obj_names, sep="\n")
|
||||
parser.exit()
|
||||
|
||||
if args.list_pages is not None:
|
||||
print(*(
|
||||
f"--page {i} {args.list_pages} {obj_name}"
|
||||
for obj_name in obj_names
|
||||
for i in range(args.list_pages)
|
||||
), sep="\n")
|
||||
parser.exit()
|
||||
|
||||
if args.precompute_mesh_sv_scan_clouds:
|
||||
read.precompute_mesh_scan_point_clouds(
|
||||
obj_names,
|
||||
compute_miss_distances = args.compute_miss_distances,
|
||||
no_filter_backhits = args.no_filter_backhits,
|
||||
no_unit_sphere = args.no_unit_sphere,
|
||||
convert_ok = args.convert_ok,
|
||||
page = args.page,
|
||||
force = args.force,
|
||||
debug = args.debug,
|
||||
)
|
||||
if args.precompute_mesh_sv_scan_uvs:
|
||||
read.precompute_mesh_scan_uvs(
|
||||
obj_names,
|
||||
compute_miss_distances = args.compute_miss_distances,
|
||||
fill_missing_points = args.fill_missing_uv_points,
|
||||
no_filter_backhits = args.no_filter_backhits,
|
||||
no_unit_sphere = args.no_unit_sphere,
|
||||
convert_ok = args.convert_ok,
|
||||
page = args.page,
|
||||
force = args.force,
|
||||
debug = args.debug,
|
||||
)
|
||||
if args.precompute_mesh_sphere_scan:
|
||||
read.precompute_mesh_sphere_scan(
|
||||
obj_names,
|
||||
sphere_points = args.n_sphere_points,
|
||||
compute_miss_distances = args.compute_miss_distances,
|
||||
no_filter_backhits = args.no_filter_backhits,
|
||||
no_unit_sphere = args.no_unit_sphere,
|
||||
page = args.page,
|
||||
force = args.force,
|
||||
debug = args.debug,
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
251
ifield/data/stanford/read.py
Normal file
251
ifield/data/stanford/read.py
Normal file
@@ -0,0 +1,251 @@
|
||||
from . import config
|
||||
from ..common import points
|
||||
from ..common import processing
|
||||
from ..common.scan import SingleViewScan, SingleViewUVScan
|
||||
from ..common.types import MalformedMesh
|
||||
from functools import lru_cache, wraps
|
||||
from typing import Optional, Iterable
|
||||
from pathlib import Path
|
||||
import gzip
|
||||
import numpy as np
|
||||
import trimesh
|
||||
import trimesh.transformations as T
|
||||
|
||||
__doc__ = """
|
||||
Here are functions for reading and preprocessing shapenet benchmark data
|
||||
|
||||
There are essentially a few sets per object:
|
||||
"img" - meaning the RGBD images (none found in stanford)
|
||||
"mesh_scans" - meaning synthetic scans of a mesh
|
||||
"""
|
||||
|
||||
MESH_TRANSFORM_SKYWARD = T.rotation_matrix(np.pi/2, (1, 0, 0))
|
||||
MESH_TRANSFORM_CANONICAL = { # to gain a shared canonical orientation
|
||||
"armadillo" : T.rotation_matrix(np.pi, (0, 0, 1)) @ MESH_TRANSFORM_SKYWARD,
|
||||
"asian_dragon" : T.rotation_matrix(-np.pi/2, (0, 0, 1)) @ MESH_TRANSFORM_SKYWARD,
|
||||
"bunny" : MESH_TRANSFORM_SKYWARD,
|
||||
"dragon" : MESH_TRANSFORM_SKYWARD,
|
||||
"drill_bit" : MESH_TRANSFORM_SKYWARD,
|
||||
"happy_buddha" : MESH_TRANSFORM_SKYWARD,
|
||||
"lucy" : T.rotation_matrix(np.pi, (0, 0, 1)),
|
||||
"thai_statue" : MESH_TRANSFORM_SKYWARD,
|
||||
}
|
||||
|
||||
def list_object_names() -> list[str]:
|
||||
# downloaded only:
|
||||
return [
|
||||
i for i, v in config.MODELS.items()
|
||||
if (config.DATA_PATH / "meshes" / v.mesh_fname).is_file()
|
||||
]
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def list_mesh_scan_sphere_coords(n_poses: int = 50) -> list[tuple[float, float]]: # (theta, phi)
|
||||
return points.generate_equidistant_sphere_points(n_poses, compute_sphere_coordinates=True)#, shift_theta=True
|
||||
|
||||
def mesh_scan_identifier(*, phi: float, theta: float) -> str:
|
||||
return (
|
||||
f"{'np'[theta>=0]}{abs(theta):.2f}"
|
||||
f"{'np'[phi >=0]}{abs(phi) :.2f}"
|
||||
).replace(".", "d")
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def list_mesh_scan_identifiers(n_poses: int = 50) -> list[str]:
|
||||
out = [
|
||||
mesh_scan_identifier(phi=phi, theta=theta)
|
||||
for theta, phi in list_mesh_scan_sphere_coords(n_poses)
|
||||
]
|
||||
assert len(out) == len(set(out))
|
||||
return out
|
||||
|
||||
# ===
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def read_mesh(obj_name: str) -> trimesh.Trimesh:
|
||||
path = config.DATA_PATH / "meshes" / config.MODELS[obj_name].mesh_fname
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"{obj_name = } -> {str(path) = }")
|
||||
try:
|
||||
if path.suffixes[-1] == ".gz":
|
||||
with gzip.open(path, "r") as f:
|
||||
mesh = trimesh.load(f, file_type="".join(path.suffixes[:-1])[1:])
|
||||
else:
|
||||
mesh = trimesh.load(path)
|
||||
except Exception as e:
|
||||
raise MalformedMesh(f"Trimesh raised: {e.__class__.__name__}: {e}") from e
|
||||
|
||||
# rotate to be upright in pyrender
|
||||
mesh.apply_transform(MESH_TRANSFORM_CANONICAL.get(obj_name, MESH_TRANSFORM_SKYWARD))
|
||||
|
||||
return mesh
|
||||
|
||||
# === single-view scan clouds
|
||||
|
||||
def compute_mesh_scan_point_cloud(
|
||||
obj_name : str,
|
||||
*,
|
||||
phi : float,
|
||||
theta : float,
|
||||
compute_miss_distances : bool = False,
|
||||
compute_normals : bool = True,
|
||||
convert_ok : bool = False, # this does not respect the other hparams
|
||||
**kw,
|
||||
) -> SingleViewScan:
|
||||
|
||||
if convert_ok:
|
||||
try:
|
||||
return read_mesh_scan_uv(obj_name, phi=phi, theta=theta).to_scan()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
mesh = read_mesh(obj_name)
|
||||
return SingleViewScan.from_mesh_single_view(mesh,
|
||||
phi = phi,
|
||||
theta = theta,
|
||||
compute_normals = compute_normals,
|
||||
compute_miss_distances = compute_miss_distances,
|
||||
**kw,
|
||||
)
|
||||
|
||||
def precompute_mesh_scan_point_clouds(obj_names, *, page: tuple[int, int] = (0, 1), force: bool = False, debug: bool = False, n_poses: int = 50, **kw):
|
||||
"precomputes all single-view scan clouds and stores them as HDF5 datasets"
|
||||
cam_poses = list_mesh_scan_sphere_coords(n_poses)
|
||||
pose_identifiers = list_mesh_scan_identifiers (n_poses)
|
||||
assert len(cam_poses) == len(pose_identifiers)
|
||||
paths = list_mesh_scan_point_cloud_h5_fnames(obj_names, pose_identifiers)
|
||||
mlen = max(map(len, config.MODELS.keys()))
|
||||
pretty_identifiers = [
|
||||
f"{obj_name.ljust(mlen)} @ {i:>5} @ ({itentifier}: {theta:.2f}, {phi:.2f})"
|
||||
for obj_name in obj_names
|
||||
for i, (itentifier, (theta, phi)) in enumerate(zip(pose_identifiers, cam_poses))
|
||||
]
|
||||
mesh_cache = []
|
||||
@wraps(compute_mesh_scan_point_cloud)
|
||||
def computer(pretty_identifier: str) -> SingleViewScan:
|
||||
obj_name, index, _ = map(str.strip, pretty_identifier.split("@"))
|
||||
theta, phi = cam_poses[int(index)]
|
||||
return compute_mesh_scan_point_cloud(obj_name, phi=phi, theta=theta, _mesh_cache=mesh_cache, **kw)
|
||||
return processing.precompute_data(computer, pretty_identifiers, paths, page=page, force=force, debug=debug)
|
||||
|
||||
def read_mesh_scan_point_cloud(obj_name, *, identifier: str = None, phi: float = None, theta: float = None) -> SingleViewScan:
|
||||
if identifier is None:
|
||||
if phi is None or theta is None:
|
||||
raise ValueError("Provide either phi+theta or an identifier!")
|
||||
identifier = mesh_scan_identifier(phi=phi, theta=theta)
|
||||
file = config.DATA_PATH / "clouds" / obj_name / f"mesh_scan_{identifier}_clouds.h5"
|
||||
if not file.exists(): raise FileNotFoundError(str(file))
|
||||
return SingleViewScan.from_h5_file(file)
|
||||
|
||||
def list_mesh_scan_point_cloud_h5_fnames(obj_names: Iterable[str], identifiers: Optional[Iterable[str]] = None, **kw) -> list[Path]:
|
||||
if identifiers is None:
|
||||
identifiers = list_mesh_scan_identifiers(**kw)
|
||||
return [
|
||||
config.DATA_PATH / "clouds" / obj_name / f"mesh_scan_{identifier}_clouds.h5"
|
||||
for obj_name in obj_names
|
||||
for identifier in identifiers
|
||||
]
|
||||
|
||||
# === single-view UV scan clouds
|
||||
|
||||
def compute_mesh_scan_uv(
|
||||
obj_name : str,
|
||||
*,
|
||||
phi : float,
|
||||
theta : float,
|
||||
compute_miss_distances : bool = False,
|
||||
fill_missing_points : bool = False,
|
||||
compute_normals : bool = True,
|
||||
convert_ok : bool = False,
|
||||
**kw,
|
||||
) -> SingleViewUVScan:
|
||||
|
||||
if convert_ok:
|
||||
try:
|
||||
return read_mesh_scan_point_cloud(obj_name, phi=phi, theta=theta).to_uv_scan()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
mesh = read_mesh(obj_name)
|
||||
scan = SingleViewUVScan.from_mesh_single_view(mesh,
|
||||
phi = phi,
|
||||
theta = theta,
|
||||
compute_normals = compute_normals,
|
||||
**kw,
|
||||
)
|
||||
if compute_miss_distances:
|
||||
scan.compute_miss_distances()
|
||||
if fill_missing_points:
|
||||
scan.fill_missing_points()
|
||||
|
||||
return scan
|
||||
|
||||
def precompute_mesh_scan_uvs(obj_names, *, page: tuple[int, int] = (0, 1), force: bool = False, debug: bool = False, n_poses: int = 50, **kw):
|
||||
"precomputes all single-view scan clouds and stores them as HDF5 datasets"
|
||||
cam_poses = list_mesh_scan_sphere_coords(n_poses)
|
||||
pose_identifiers = list_mesh_scan_identifiers (n_poses)
|
||||
assert len(cam_poses) == len(pose_identifiers)
|
||||
paths = list_mesh_scan_uv_h5_fnames(obj_names, pose_identifiers)
|
||||
mlen = max(map(len, config.MODELS.keys()))
|
||||
pretty_identifiers = [
|
||||
f"{obj_name.ljust(mlen)} @ {i:>5} @ ({itentifier}: {theta:.2f}, {phi:.2f})"
|
||||
for obj_name in obj_names
|
||||
for i, (itentifier, (theta, phi)) in enumerate(zip(pose_identifiers, cam_poses))
|
||||
]
|
||||
mesh_cache = []
|
||||
@wraps(compute_mesh_scan_uv)
|
||||
def computer(pretty_identifier: str) -> SingleViewScan:
|
||||
obj_name, index, _ = map(str.strip, pretty_identifier.split("@"))
|
||||
theta, phi = cam_poses[int(index)]
|
||||
return compute_mesh_scan_uv(obj_name, phi=phi, theta=theta, _mesh_cache=mesh_cache, **kw)
|
||||
return processing.precompute_data(computer, pretty_identifiers, paths, page=page, force=force, debug=debug)
|
||||
|
||||
def read_mesh_scan_uv(obj_name, *, identifier: str = None, phi: float = None, theta: float = None) -> SingleViewUVScan:
|
||||
if identifier is None:
|
||||
if phi is None or theta is None:
|
||||
raise ValueError("Provide either phi+theta or an identifier!")
|
||||
identifier = mesh_scan_identifier(phi=phi, theta=theta)
|
||||
file = config.DATA_PATH / "clouds" / obj_name / f"mesh_scan_{identifier}_uv.h5"
|
||||
if not file.exists(): raise FileNotFoundError(str(file))
|
||||
return SingleViewUVScan.from_h5_file(file)
|
||||
|
||||
def list_mesh_scan_uv_h5_fnames(obj_names: Iterable[str], identifiers: Optional[Iterable[str]] = None, **kw) -> list[Path]:
|
||||
if identifiers is None:
|
||||
identifiers = list_mesh_scan_identifiers(**kw)
|
||||
return [
|
||||
config.DATA_PATH / "clouds" / obj_name / f"mesh_scan_{identifier}_uv.h5"
|
||||
for obj_name in obj_names
|
||||
for identifier in identifiers
|
||||
]
|
||||
|
||||
# === sphere-view (UV) scan clouds
|
||||
|
||||
def compute_mesh_sphere_scan(
|
||||
obj_name : str,
|
||||
*,
|
||||
compute_normals : bool = True,
|
||||
**kw,
|
||||
) -> SingleViewUVScan:
|
||||
mesh = read_mesh(obj_name)
|
||||
scan = SingleViewUVScan.from_mesh_sphere_view(mesh,
|
||||
compute_normals = compute_normals,
|
||||
**kw,
|
||||
)
|
||||
return scan
|
||||
|
||||
def precompute_mesh_sphere_scan(obj_names, *, page: tuple[int, int] = (0, 1), force: bool = False, debug: bool = False, n_points: int = 4000, **kw):
|
||||
"precomputes all single-view scan clouds and stores them as HDF5 datasets"
|
||||
paths = list_mesh_sphere_scan_h5_fnames(obj_names)
|
||||
@wraps(compute_mesh_sphere_scan)
|
||||
def computer(obj_name: str) -> SingleViewScan:
|
||||
return compute_mesh_sphere_scan(obj_name, **kw)
|
||||
return processing.precompute_data(computer, obj_names, paths, page=page, force=force, debug=debug)
|
||||
|
||||
def read_mesh_mesh_sphere_scan(obj_name) -> SingleViewUVScan:
|
||||
file = config.DATA_PATH / "clouds" / obj_name / "mesh_sphere_scan.h5"
|
||||
if not file.exists(): raise FileNotFoundError(str(file))
|
||||
return SingleViewUVScan.from_h5_file(file)
|
||||
|
||||
def list_mesh_sphere_scan_h5_fnames(obj_names: Iterable[str]) -> list[Path]:
|
||||
return [
|
||||
config.DATA_PATH / "clouds" / obj_name / "mesh_sphere_scan.h5"
|
||||
for obj_name in obj_names
|
||||
]
|
||||
Reference in New Issue
Block a user