Add code
This commit is contained in:
101
ifield/models/orthogonal_plane.py
Normal file
101
ifield/models/orthogonal_plane.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from .. import param
|
||||
from ..modules import fc
|
||||
from ..utils import geometry
|
||||
from ..utils.helpers import compose
|
||||
from textwrap import indent, dedent
|
||||
from torch import nn, Tensor
|
||||
from typing import Optional
|
||||
import warnings
|
||||
|
||||
class OrthogonalPlaneNet(nn.Module):
|
||||
"""
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_features : int,
|
||||
latent_features : int,
|
||||
hidden_features : int,
|
||||
hidden_layers : int,
|
||||
**kw,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.fc = fc.FCBlock(
|
||||
in_features = in_features,
|
||||
hidden_layers = hidden_layers,
|
||||
hidden_features = hidden_features,
|
||||
out_features = 2, # (plane_offset, is_intersecting)
|
||||
outermost_linear = True,
|
||||
latent_features = latent_features,
|
||||
**kw,
|
||||
)
|
||||
|
||||
@property
|
||||
def is_conditioned(self):
|
||||
return self.fc.is_conditioned
|
||||
|
||||
@classmethod
|
||||
@compose("\n".join)
|
||||
def make_jinja_template(cls, *, exclude_list: set[str] = {}, top_level: bool = True, **kw) -> str:
|
||||
yield param.make_jinja_template(cls, top_level=top_level, exclude_list=exclude_list, **kw)
|
||||
yield param.make_jinja_template(fc.FCBlock, top_level=False, exclude_list={
|
||||
"in_features",
|
||||
"hidden_layers",
|
||||
"hidden_features",
|
||||
"out_features",
|
||||
"outermost_linear",
|
||||
})
|
||||
|
||||
def forward(self, x: Tensor, z: Optional[Tensor] = None) -> Tensor:
|
||||
if __debug__ and self.is_conditioned and z is None:
|
||||
warnings.warn(f"{self.__class__.__qualname__} is conditioned, but the forward pass was not supplied with a conditioning tensor.")
|
||||
return self.fc(x, z)
|
||||
|
||||
@staticmethod
|
||||
def compute_intersections(
|
||||
ray_origins : Tensor, # (..., 3)
|
||||
ray_dirs : Tensor, # (..., 3)
|
||||
predictions : Tensor, # (..., 2)
|
||||
*,
|
||||
normalize_origins = True,
|
||||
return_signed_displacements = False,
|
||||
allow_nans = False, # MARF compat
|
||||
atom_random_prob = None, # MARF compat
|
||||
atom_dropout_prob = None, # MARF compat
|
||||
) -> tuple[(Tensor,)*5]:
|
||||
assert ray_origins.shape[:-1] == ray_dirs.shape[:-1] == predictions.shape[:-1], \
|
||||
(ray_origins.shape, ray_dirs.shape, predictions.shape)
|
||||
assert predictions.shape[-1] == 2, \
|
||||
predictions.shape
|
||||
|
||||
assert not allow_nans
|
||||
|
||||
if normalize_origins:
|
||||
ray_origins = geometry.project_point_on_ray(0, ray_origins, ray_dirs)
|
||||
|
||||
# unpack predictions
|
||||
signed_displacements = predictions[..., 0]
|
||||
is_intersecting = predictions[..., 1]
|
||||
|
||||
# compute intersections
|
||||
intersections = ray_origins - signed_displacements[..., None] * ray_dirs
|
||||
|
||||
return (
|
||||
intersections,
|
||||
is_intersecting,
|
||||
*((signed_displacements,) if return_signed_displacements else ()),
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
OrthogonalPlaneNet.__doc__ = __doc__ = f"""
|
||||
{dedent(OrthogonalPlaneNet.__doc__).strip()}
|
||||
|
||||
# Config template:
|
||||
|
||||
```yaml
|
||||
{OrthogonalPlaneNet.make_jinja_template()}
|
||||
```
|
||||
"""
|
||||
Reference in New Issue
Block a user