Files
marf/ifield/models/orthogonal_plane.py
2025-01-09 15:43:11 +01:00

102 lines
3.1 KiB
Python

from .. import param
from ..modules import fc
from ..utils import geometry
from ..utils.helpers import compose
from textwrap import indent, dedent
from torch import nn, Tensor
from typing import Optional
import warnings
class OrthogonalPlaneNet(nn.Module):
    """
    Fully-connected network that predicts two values per input,
    `(plane_offset, is_intersecting)`, using a single `fc.FCBlock` with
    `out_features = 2` and a linear output layer.

    `compute_intersections` turns these predictions into per-ray
    intersection points.
    """

    def __init__(self,
        in_features     : int,
        latent_features : int,
        hidden_features : int,
        hidden_layers   : int,
        **kw,
    ):
        """
        Args:
            in_features: Forwarded to `fc.FCBlock` as `in_features`.
            latent_features: Forwarded to `fc.FCBlock` as `latent_features`
                (size of the conditioning input, judging by the name; confirm
                against FCBlock).
            hidden_features: Forwarded to `fc.FCBlock` as `hidden_features`.
            hidden_layers: Forwarded to `fc.FCBlock` as `hidden_layers`.
            **kw: Any further keyword arguments, forwarded verbatim to `fc.FCBlock`.
        """
        super().__init__()
        self.fc = fc.FCBlock(
            in_features      = in_features,
            hidden_layers    = hidden_layers,
            hidden_features  = hidden_features,
            out_features     = 2, # (plane_offset, is_intersecting)
            outermost_linear = True,
            latent_features  = latent_features,
            **kw,
        )

    @property
    def is_conditioned(self):
        """Whether the wrapped `fc.FCBlock` reports itself as conditioned (see `forward`)."""
        return self.fc.is_conditioned

    @classmethod
    @compose("\n".join)
    def make_jinja_template(
        cls,
        *,
        exclude_list: Optional[set[str]] = None,
        top_level: bool = True,
        **kw,
    ) -> str:
        """
        Yield the config template for this class, followed by the template for
        the `fc.FCBlock` parameters that `__init__` does not fix itself.

        The `compose("\\n".join)` decorator presumably joins the yielded parts
        into a single string; confirm against `utils.helpers.compose`.

        Args:
            exclude_list: Parameter names of this class to leave out of the
                template. Defaults to an empty set. A fresh set is created on
                every call, so the default is never shared between calls.
            top_level: Forwarded to `param.make_jinja_template` for this class.
            **kw: Forwarded to `param.make_jinja_template` for this class.
        """
        if exclude_list is None:
            exclude_list = set()
        yield param.make_jinja_template(cls, top_level=top_level, exclude_list=exclude_list, **kw)
        # These FCBlock parameters are hard-wired (or derived from this class's
        # own parameters) in __init__, so they are not user-configurable here.
        yield param.make_jinja_template(fc.FCBlock, top_level=False, exclude_list={
            "in_features",
            "hidden_layers",
            "hidden_features",
            "out_features",
            "outermost_linear",
        })

    def forward(self, x: Tensor, z: Optional[Tensor] = None) -> Tensor:
        """
        Run the wrapped FCBlock on `x`, with optional conditioning tensor `z`.

        The output presumably has shape (..., 2) = (plane_offset,
        is_intersecting), given `out_features = 2`. Confirm against FCBlock.
        In debug mode a warning is issued if the network is conditioned but
        `z` is missing.
        """
        if __debug__ and self.is_conditioned and z is None:
            warnings.warn(f"{self.__class__.__qualname__} is conditioned, but the forward pass was not supplied with a conditioning tensor.")
        return self.fc(x, z)

    @staticmethod
    def compute_intersections(
        ray_origins : Tensor, # (..., 3)
        ray_dirs    : Tensor, # (..., 3)
        predictions : Tensor, # (..., 2)
        *,
        normalize_origins           : bool = True,
        return_signed_displacements : bool = False,
        allow_nans                  : bool = False, # MARF compat
        atom_random_prob            = None,         # MARF compat
        atom_dropout_prob           = None,         # MARF compat
    ) -> tuple[Tensor, ...]:
        """
        Convert per-ray predictions into intersection points.

        Args:
            ray_origins: Ray origins, (..., 3).
            ray_dirs: Ray directions, (..., 3). Leading dims must match `ray_origins`.
            predictions: Network output, (..., 2): `[..., 0]` is a signed
                displacement and `[..., 1]` is the intersection score.
            normalize_origins: If True, first replace each origin with
                `geometry.project_point_on_ray(0, ray_origins, ray_dirs)`.
                Judging by the name, this is the point on the ray closest to
                the world origin; confirm against `geometry`.
            return_signed_displacements: Also return `predictions[..., 0]`.
            allow_nans: Must be False. Accepted for MARF API compatibility only.
            atom_random_prob: Accepted for MARF API compatibility. Unused here.
            atom_dropout_prob: Accepted for MARF API compatibility. Unused here.

        Returns:
            `(intersections, is_intersecting)`, or
            `(intersections, is_intersecting, signed_displacements)` when
            `return_signed_displacements` is True. `intersections` is
            `origin - displacement * dir`, shape (..., 3). `is_intersecting` is
            `predictions[..., 1]`, passed through unchanged.
        """
        assert ray_origins.shape[:-1] == ray_dirs.shape[:-1] == predictions.shape[:-1], \
            (ray_origins.shape, ray_dirs.shape, predictions.shape)
        assert predictions.shape[-1] == 2, \
            predictions.shape
        assert not allow_nans

        if normalize_origins:
            ray_origins = geometry.project_point_on_ray(0, ray_origins, ray_dirs)

        # unpack predictions
        signed_displacements = predictions[..., 0]
        is_intersecting      = predictions[..., 1]

        # compute intersections by moving along the ray by -displacement
        intersections = ray_origins - signed_displacements[..., None] * ray_dirs

        return (
            intersections,
            is_intersecting,
            *((signed_displacements,) if return_signed_displacements else ()),
        )
# At import time, publish the class docstring followed by its rendered YAML
# config template. The same text is also used as the module docstring.
# The `or ''` fallback keeps the import working under `python -OO`: that mode
# strips docstrings, leaving `OrthogonalPlaneNet.__doc__` as None, and
# `dedent(None)` would raise TypeError.
OrthogonalPlaneNet.__doc__ = __doc__ = f"""
{dedent(OrthogonalPlaneNet.__doc__ or '').strip()}
# Config template:
```yaml
{OrthogonalPlaneNet.make_jinja_template()}
```
"""