102 lines
3.1 KiB
Python
102 lines
3.1 KiB
Python
from .. import param
|
|
from ..modules import fc
|
|
from ..utils import geometry
|
|
from ..utils.helpers import compose
|
|
from textwrap import indent, dedent
|
|
from torch import nn, Tensor
|
|
from typing import Optional
|
|
import warnings
|
|
|
|
class OrthogonalPlaneNet(nn.Module):
    """
    Predicts a signed plane offset and an intersection value for each input,
    using a single fully-connected block.

    The underlying ``fc.FCBlock`` is configured with two linear outputs laid
    out as ``(plane_offset, is_intersecting)``. ``compute_intersections``
    turns a batch of such predictions into 3D points along the corresponding
    rays.
    """

    def __init__(self,
            in_features     : int,
            latent_features : int,
            hidden_features : int,
            hidden_layers   : int,
            **kw,
            ):
        """
        Build the inner ``fc.FCBlock`` with two linear outputs.

        Args:
            in_features: Forwarded to ``fc.FCBlock`` as ``in_features``.
            latent_features: Forwarded to ``fc.FCBlock`` as
                ``latent_features`` (presumably the width of the conditioning
                tensor ``z`` passed to ``forward`` -- confirm in ``FCBlock``).
            hidden_features: Forwarded to ``fc.FCBlock``.
            hidden_layers: Forwarded to ``fc.FCBlock``.
            **kw: Any other keyword arguments, forwarded verbatim to
                ``fc.FCBlock``.
        """
        super().__init__()

        self.fc = fc.FCBlock(
            in_features      = in_features,
            hidden_layers    = hidden_layers,
            hidden_features  = hidden_features,
            out_features     = 2, # (plane_offset, is_intersecting)
            outermost_linear = True,
            latent_features  = latent_features,
            **kw,
        )

    @property
    def is_conditioned(self):
        """Whether the inner FC block is conditioned (delegates to ``self.fc``)."""
        return self.fc.is_conditioned

    @classmethod
    @compose("\n".join)
    def make_jinja_template(cls, *, exclude_list: Optional[set[str]] = None, top_level: bool = True, **kw) -> str:
        """
        Build the config template for this class followed by the template of
        its inner ``fc.FCBlock``, joined with newlines.

        Args:
            exclude_list: Parameter names of this class to leave out of its
                own template. ``None`` (the default) means no exclusions.
            top_level: Forwarded to ``param.make_jinja_template`` for this
                class.
            **kw: Extra keyword arguments forwarded to
                ``param.make_jinja_template`` for this class.

        Returns:
            The two templates joined by ``"\\n"`` (via ``compose``).
        """
        # Fresh set per call: the former default ``{}`` was a single shared
        # mutable object, and a dict rather than the annotated set.
        if exclude_list is None:
            exclude_list = set()

        yield param.make_jinja_template(cls, top_level=top_level, exclude_list=exclude_list, **kw)

        # These FCBlock arguments are hard-wired by OrthogonalPlaneNet.__init__,
        # so they are hidden from the FCBlock part of the template.
        # NOTE(review): ``latent_features`` is also set by __init__ but is not
        # excluded here -- confirm whether it is meant to appear in both parts.
        yield param.make_jinja_template(fc.FCBlock, top_level=False, exclude_list={
            "in_features",
            "hidden_layers",
            "hidden_features",
            "out_features",
            "outermost_linear",
        })

    def forward(self, x: Tensor, z: Optional[Tensor] = None) -> Tensor:
        """
        Run the FC block on ``x``, optionally conditioned on ``z``.

        Emits a warning (only when ``__debug__`` is true) if the block is
        conditioned but no ``z`` was given.

        Returns:
            The raw FC-block output. The block is configured with
            ``out_features=2``, i.e. ``(plane_offset, is_intersecting)``.
        """
        if __debug__ and self.is_conditioned and z is None:
            warnings.warn(
                f"{self.__class__.__qualname__} is conditioned, but the forward pass was not supplied with a conditioning tensor.",
                stacklevel=2,  # attribute the warning to the caller of forward()
            )
        return self.fc(x, z)

    @staticmethod
    def compute_intersections(
            ray_origins : Tensor, # (..., 3)
            ray_dirs    : Tensor, # (..., 3)
            predictions : Tensor, # (..., 2)
            *,
            normalize_origins           = True,
            return_signed_displacements = False,
            allow_nans                  = False, # MARF compat
            atom_random_prob            = None,  # MARF compat
            atom_dropout_prob           = None,  # MARF compat
            ) -> tuple[Tensor, ...]:
        """
        Turn per-ray predictions into intersection points.

        Each point is ``ray_origins - signed_displacement * ray_dirs`` with
        ``signed_displacement = predictions[..., 0]``.

        Args:
            ray_origins: Ray origins, shape ``(..., 3)``.
            ray_dirs: Ray directions, shape ``(..., 3)``.
            predictions: Network output, shape ``(..., 2)``, laid out as
                ``(signed_displacement, is_intersecting)``.
            normalize_origins: If true, first replace the origins with
                ``geometry.project_point_on_ray(0, ray_origins, ray_dirs)``
                (presumably the point on each ray closest to the world
                origin, giving displacements a common reference -- confirm).
            return_signed_displacements: If true, also return
                ``predictions[..., 0]``.
            allow_nans: Must be false (asserted).
            atom_random_prob: Accepted for MARF call compatibility; ignored.
            atom_dropout_prob: Accepted for MARF call compatibility; ignored.

        Returns:
            ``(intersections, is_intersecting)``, plus
            ``signed_displacements`` when ``return_signed_displacements`` is
            true. ``intersections`` has shape ``(..., 3)``; the others have
            shape ``(...)``.
        """
        assert ray_origins.shape[:-1] == ray_dirs.shape[:-1] == predictions.shape[:-1], \
            (ray_origins.shape, ray_dirs.shape, predictions.shape)
        assert predictions.shape[-1] == 2, \
            predictions.shape

        assert not allow_nans

        if normalize_origins:
            ray_origins = geometry.project_point_on_ray(0, ray_origins, ray_dirs)

        # unpack predictions
        signed_displacements = predictions[..., 0]
        is_intersecting      = predictions[..., 1]

        # compute intersections: step back along each ray by its displacement
        intersections = ray_origins - signed_displacements[..., None] * ray_dirs

        return (
            intersections,
            is_intersecting,
            *((signed_displacements,) if return_signed_displacements else ()),
        )
|
|
|
|
|
|
|
|
|
|
# Publish the class docstring followed by its YAML config template as both
# the class's and this module's documentation.
# ``OrthogonalPlaneNet.__doc__`` is None under ``python -OO`` (docstrings are
# stripped), so fall back to '' instead of crashing at import time in dedent().
OrthogonalPlaneNet.__doc__ = __doc__ = f"""
{dedent(OrthogonalPlaneNet.__doc__ or '').strip()}

# Config template:

```yaml
{OrthogonalPlaneNet.make_jinja_template()}
```
"""
|