from .. import param
from ..modules.dtype import DtypeMixin
from ..utils import geometry
from ..utils.helpers import compose
from ..utils.loss import Schedulable, ensure_schedulables, HParamSchedule, HParamScheduleBase, Linear
from ..utils.operators import diff
from .conditioning import RequiresConditioner, AutoDecoderModuleMixin
from .medial_atoms import MedialAtomNet
from .orthogonal_plane import OrthogonalPlaneNet
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import Tensor
from torch.nn import functional as F
from typing import TypedDict, Literal, Union, Hashable, Optional
import pytorch_lightning as pl
import torch
import os

LOG_ALL_METRICS = bool(int(os.environ.get("IFIELD_LOG_ALL_METRICS", "1")))

if __debug__:
    def broadcast_tensors(*tensors: torch.Tensor) -> list[torch.Tensor]:
        try:
            return torch.broadcast_tensors(*tensors)
        except RuntimeError as e:
            shapes = ", ".join(f"{chr(c)}.size={tuple(t.shape)}" for c, t in enumerate(tensors, ord("a")))
            raise ValueError(f"Could not broadcast tensors {shapes}.\n{str(e)}")
else:
    broadcast_tensors = torch.broadcast_tensors


class ForwardDepthMapsBatch(TypedDict):
    cam2world  : Tensor # (B, 4, 4)
    uv         : Tensor # (B, H, W)
    intrinsics : Tensor # (B, 3, 3)

class ForwardScanRaysBatch(TypedDict):
    origins : Tensor # (B, H, W, 3) or (B, 3)
    dirs    : Tensor # (B, H, W, 3)

class LossBatch(TypedDict):
    hits      : Tensor # (B, H, W) dtype=bool
    miss      : Tensor # (B, H, W) dtype=bool
    depths    : Tensor # (B, H, W)
    normals   : Tensor # (B, H, W, 3) NaN if not hit
    distances : Tensor # (B, H, W, 1) NaN if not miss

class LabeledBatch(TypedDict):
    z_uid : list[Hashable]

ForwardBatch  = Union[ForwardDepthMapsBatch, ForwardScanRaysBatch]
TrainingBatch = Union[ForwardBatch, LossBatch, LabeledBatch]

IntersectionMode = Literal[
    "medial_sphere",
    "orthogonal_plane",
]
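# Example (illustrative only): a minimal ForwardScanRaysBatch for one object with a
# 2x2 grid of rays could look like
#   {
#       "origins": torch.zeros(1, 3),   # (B, 3), broadcast against dirs in forward()
#       "dirs":    unit_dirs,           # (B, H, W, 3), assumed normalized to unit length
#   }
# ForwardDepthMapsBatch (cam2world/uv/intrinsics) is declared above and accepted by
# batch2rays(), but that code path currently raises NotImplementedError.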
class IntersectionFieldModel(pl.LightningModule, RequiresConditioner, DtypeMixin):
    net: Union[MedialAtomNet, OrthogonalPlaneNet]

    @ensure_schedulables
    def __init__(self,
        # mode
        input_mode  : geometry.RayEmbedding = "plucker",
        output_mode : IntersectionMode      = "medial_sphere",
        # network
        latent_features    : int  = 256,
        hidden_features    : int  = 512,
        hidden_layers      : int  = 8,
        improve_miss_grads : bool = True,
        normalize_ray_dirs : bool = False, # the dataset is usually already normalized, but this could still be important for backprop
        # orthogonal plane
        loss_hit_cross_entropy    : Schedulable = 1.0,
        # medial atoms
        loss_intersection         : Schedulable = 1,
        loss_intersection_l2      : Schedulable = 0,
        loss_intersection_proj    : Schedulable = 0,
        loss_intersection_proj_l2 : Schedulable = 0,
        loss_normal_cossim        : Schedulable = 0.25, # supervise target normal cosine similarity
        loss_normal_euclid        : Schedulable = 0,    # supervise target normal l2 distance
        loss_normal_cossim_proj   : Schedulable = 0,    # supervise target normal cosine similarity
        loss_normal_euclid_proj   : Schedulable = 0,    # supervise target normal l2 distance
        loss_hit_nodistance_l1    : Schedulable = 0,    # constrain hits to have no miss distance
        loss_hit_nodistance_l2    : Schedulable = 32,   # constrain hits to have no miss distance
        loss_miss_distance_l1     : Schedulable = 0,    # supervise target miss distance for misses
        loss_miss_distance_l2     : Schedulable = 0,    # supervise target miss distance for misses
        loss_inscription_hits     : Schedulable = 0,    # penalize atom candidates using the supervision data of a different ray
        loss_inscription_hits_l2  : Schedulable = 0,    # penalize atom candidates using the supervision data of a different ray
        loss_inscription_miss     : Schedulable = 0,    # penalize atom candidates using the supervision data of a different ray
        loss_inscription_miss_l2  : Schedulable = 0,    # penalize atom candidates using the supervision data of a different ray
        loss_sphere_grow_reg      : Schedulable = 0,    # maximize sphere size
        loss_sphere_grow_reg_hit  : Schedulable = 0,    # maximize sphere size
        loss_embedding_norm       : Schedulable = "0.01**2 * Linear(15)", # DeepSDF schedules over 150 epochs. DeepSDF uses 0.01**2, irobot uses 0.04**2
        loss_multi_view_reg       : Schedulable = 0,    # minimize gradient w.r.t. delta ray dir, when ray origin = intersection
        loss_atom_centroid_norm_std_reg : Schedulable = 0, # minimize per-atom centroid std
        # optimization
        opt_learning_rate : Schedulable = 1e-5,
        opt_weight_decay  : float       = 0,
        opt_warmup        : float       = 0,
        **kw,
    ):
        super().__init__()
        opt_warmup = Linear(opt_warmup)
        opt_warmup._param_name = "opt_warmup"
        self.save_hyperparameters()

        if "half" in input_mode:
            assert output_mode == "medial_sphere" and kw.get("n_atoms", 1) > 1
        assert output_mode in ["medial_sphere", "orthogonal_plane"]
        assert opt_weight_decay >= 0, opt_weight_decay

        if output_mode == "orthogonal_plane":
            self.net = OrthogonalPlaneNet(
                in_features     = self.n_input_embedding_features,
                hidden_layers   = hidden_layers,
                hidden_features = hidden_features,
                latent_features = latent_features,
                **kw,
            )
        elif output_mode == "medial_sphere":
            self.net = MedialAtomNet(
                in_features     = self.n_input_embedding_features,
                hidden_layers   = hidden_layers,
                hidden_features = hidden_features,
                latent_features = latent_features,
                **kw,
            )

    def on_fit_start(self):
        if __debug__:
            for k, v in self.hparams.items():
                if isinstance(v, HParamScheduleBase):
                    v.assert_positive(self.trainer.max_epochs)

    @property
    def n_input_embedding_features(self) -> int:
        return geometry.ray_input_embedding_length(self.hparams.input_mode)

    @property
    def n_latent_features(self) -> int:
        return self.hparams.latent_features

    @property
    def latent_embeddings_init_std(self) -> float:
        return 0.01

    @property
    def is_conditioned(self):
        return self.net.is_conditioned

    @property
    def is_double_backprop(self) -> bool:
        return self.is_double_backprop_origins or self.is_double_backprop_dirs

    @property
    def is_double_backprop_origins(self) -> bool:
        prif = self.hparams.output_mode == "orthogonal_plane"
        return prif and self.hparams.loss_normal_cossim

    @property
    def is_double_backprop_dirs(self) -> bool:
        return self.hparams.loss_multi_view_reg

    @classmethod
    @compose("\n".join)
    def make_jinja_template(cls, *, exclude_list: set[str] = {}, top_level: bool = True, **kw) -> str:
        yield param.make_jinja_template(cls, top_level=top_level, **kw)
        yield MedialAtomNet.make_jinja_template(top_level=False, exclude_list={
            "in_features",
            "hidden_layers",
            "hidden_features",
            "latent_features",
        })

    def batch2rays(self, batch: ForwardBatch) -> tuple[Tensor, Tensor]:
        if "uv" in batch:
            raise NotImplementedError
            # NOTE: unreachable while the camera/depth-map path is disabled
            assert not (self.hparams.loss_multi_view_reg and self.training)
            ray_origins, ray_dirs = geometry.camera_uv_to_rays(
                cam2world  = batch["cam2world"],
                uv         = batch["uv"],
                intrinsics = batch["intrinsics"],
            )
        else:
            ray_origins = batch["points" if self.hparams.loss_multi_view_reg and self.training else "origins"]
            ray_dirs    = batch["dirs"]
        return ray_origins, ray_dirs
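    # forward() consumes a ForwardBatch: rays are extracted by batch2rays(), embedded
    # with geometry.ray_input_embedding (according to input_mode, e.g. "plucker"),
    # pushed through self.net, and turned into intersections by
    # self.net.compute_intersections(). With return_input=True the (broadcast) ray
    # origins/dirs and the embedding are returned as well; training_step relies on
    # this for the losses that differentiate w.r.t. the ray parameters.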
    def forward(self,
        batch : ForwardBatch,
        z     : Optional[Tensor] = None, # latent code
        *,
        return_input : bool = False,
        allow_nans   : bool = False, # in output
        **kw,
    ) -> tuple[torch.Tensor, ...]:
        (
            ray_origins, # (B, 3)
            ray_dirs,    # (B, H, W, 3)
        ) = self.batch2rays(batch)

        # Ensure rays are normalized
        # NOTICE: this check is slow; train with optimizations (python -O) to strip it
        assert ray_dirs.detach().norm(dim=-1).allclose(torch.ones(ray_dirs.shape[:-1], **self.device_and_dtype)), \
            ray_dirs.detach().norm(dim=-1)

        if ray_origins.ndim + 2 == ray_dirs.ndim:
            ray_origins = ray_origins[..., None, None, :]
        ray_origins, ray_dirs = broadcast_tensors(ray_origins, ray_dirs)

        if self.is_double_backprop and self.training:
            if self.is_double_backprop_dirs:
                ray_dirs.requires_grad = True
            if self.is_double_backprop_origins:
                ray_origins.requires_grad = True
            assert ray_origins.requires_grad or ray_dirs.requires_grad

        input = geometry.ray_input_embedding(
            ray_origins, ray_dirs,
            mode           = self.hparams.input_mode,
            normalize_dirs = self.hparams.normalize_ray_dirs,
            is_training    = self.training,
        )
        assert not input.detach().isnan().any()

        predictions = self.net(input, z)

        intersections = self.net.compute_intersections(
            ray_origins, ray_dirs, predictions,
            allow_nans = allow_nans and not self.training,
            **kw,
        )

        if return_input:
            return ray_origins, ray_dirs, input, intersections
        else:
            return intersections

    def training_step(self, batch: TrainingBatch, batch_idx: int, *, is_validation=False) -> dict[str, Tensor]:
        z = self.encode(batch) if self.is_conditioned else None
        assert self.is_conditioned or len(set(batch["z_uid"])) <= 1, \
            f"Network is unconditioned, but the batch has multiple uids: {set(batch['z_uid'])!r}"

        # unpack
        target_hits      = batch["hits"]      # (B, H, W) dtype=bool
        target_miss      = batch["miss"]      # (B, H, W) dtype=bool
        target_points    = batch["points"]    # (B, H, W, 3)
        target_normals   = batch["normals"]   # (B, H, W, 3) NaN if not hit
        target_distances = batch["distances"] # (B, H, W) NaN if not miss
        assert not target_normals  [target_hits].isnan().any()
        assert not target_distances[target_miss].isnan().any()
        target_normals[target_normals.isnan()] = 0
        assert not target_normals.isnan().any()

        # make z fit the batch scheme
        if z is not None:
            z = z[..., None, None, :]

        losses  = {}
        metrics = {}
        zeros = torch.zeros_like(target_distances)

        if self.hparams.output_mode == "medial_sphere":
            assert isinstance(self.net, MedialAtomNet)
            ray_origins, ray_dirs, plucker, (
                depths,                   # (...) float, projection if not hit
                silhouettes,              # (...) float
                intersections,            # (..., 3) float, projection or NaN if not hit
                intersection_normals,     # (..., 3) float, rejection or NaN if not hit
                is_intersecting,          # (...) bool, true if hit
                sphere_centers,           # (..., 3) network output
                sphere_radii,             # (...) network output
                atom_indices,
                all_intersections,        # (..., N_ATOMS, 3) float, projection or NaN if not hit
                all_intersection_normals, # (..., N_ATOMS, 3) float, rejection or NaN if not hit
                all_depths,               # (..., N_ATOMS) float, projection if not hit
                all_silhouettes,          # (..., N_ATOMS) float
                all_is_intersecting,      # (..., N_ATOMS) bool, true if hit
                all_sphere_centers,       # (..., N_ATOMS, 3) network output
                all_sphere_radii,         # (..., N_ATOMS) network output
            ) = self(batch, z,
                intersections_only = False,
                return_all_atoms   = True,
                allow_nans         = False,
                return_input       = True,
                improve_miss_grads = True,
            )
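            # All per-ray losses below share the same masked-mean pattern:
            #   loss = per_ray_term.where(mask, zeros).mean()
            # Rays outside the mask contribute exactly 0 (rather than NaN) to the
            # mean, so every term is normalized by the total ray count, not by the
            # number of rays selected by the mask.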
            # target hit supervision
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_intersection:
                # scores true hits
                losses["loss_intersection"] = (
                    (target_points - intersections).norm(dim=-1)
                ).where(target_hits & is_intersecting, zeros).mean()
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_intersection_l2:
                # scores true hits
                losses["loss_intersection_l2"] = (
                    (target_points - intersections).pow(2).sum(dim=-1)
                ).where(target_hits & is_intersecting, zeros).mean()
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_intersection_proj:
                # scores misses as if they were hits, using the projection
                losses["loss_intersection_proj"] = (
                    (target_points - intersections).norm(dim=-1)
                ).where(target_hits, zeros).mean()
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_intersection_proj_l2:
                # scores misses as if they were hits, using the projection
                losses["loss_intersection_proj_l2"] = (
                    (target_points - intersections).pow(2).sum(dim=-1)
                ).where(target_hits, zeros).mean()

            # target hit normal supervision
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_normal_cossim:
                # scores true hits
                losses["loss_normal_cossim"] = (
                    1 - torch.cosine_similarity(target_normals, intersection_normals, dim=-1)
                ).where(target_hits & is_intersecting, zeros).mean()
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_normal_euclid:
                # scores true hits
                losses["loss_normal_euclid"] = (
                    (target_normals - intersection_normals).norm(dim=-1)
                ).where(target_hits & is_intersecting, zeros).mean()
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_normal_cossim_proj:
                # scores misses as if they were hits
                losses["loss_normal_cossim_proj"] = (
                    1 - torch.cosine_similarity(target_normals, intersection_normals, dim=-1)
                ).where(target_hits, zeros).mean()
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_normal_euclid_proj:
                # scores misses as if they were hits
                losses["loss_normal_euclid_proj"] = (
                    (target_normals - intersection_normals).norm(dim=-1)
                ).where(target_hits, zeros).mean()

            # target sufficient hit radius
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_hit_nodistance_l1:
                # ensures hits become hits, instead of relying on the projection being right
                losses["loss_hit_nodistance_l1"] = (
                    silhouettes
                ).where(target_hits & (silhouettes > 0), zeros).mean()
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_hit_nodistance_l2:
                # ensures hits become hits, instead of relying on the projection being right
                losses["loss_hit_nodistance_l2"] = (
                    silhouettes
                ).where(target_hits & (silhouettes > 0), zeros).pow(2).mean()

            # target miss supervision
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_miss_distance_l1:
                # only positive misses reinforcement
                losses["loss_miss_distance_l1"] = (
                    target_distances - silhouettes
                ).where(target_miss, zeros).abs().mean()
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_miss_distance_l2:
                # only positive misses reinforcement
                losses["loss_miss_distance_l2"] = (
                    target_distances - silhouettes
                ).where(target_miss, zeros).pow(2).mean()
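            # Geometry note (inferred from inscr_silhouettes below, where the same
            # quantity is computed explicitly): `silhouettes` is the clearance between
            # the ray and the predicted sphere, i.e. the distance from the sphere
            # center to its projection onto the ray, minus the radius. Worked example:
            # center 2.0 away from the ray with radius 1.5 gives silhouette 0.5 (a miss
            # by 0.5), while radius 2.5 gives -0.5 (the ray pierces the sphere). Hits
            # are therefore pushed towards silhouette <= 0 (loss_hit_nodistance_*), and
            # misses towards the measured free-space distance (loss_miss_distance_*).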
            # incentivise maximal spheres
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_sphere_grow_reg:
                # all atoms; constant value, but the gradient pushes the radii to grow
                losses["loss_sphere_grow_reg"] = ((all_sphere_radii.detach() + 1) - all_sphere_radii).abs().mean()
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_sphere_grow_reg_hit:
                # true hits only
                losses["loss_sphere_grow_reg_hit"] = ((sphere_radii.detach() + 1) - sphere_radii).where(target_hits & is_intersecting, zeros).abs().mean()

            # spherical latent prior
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_embedding_norm:
                losses["loss_embedding_norm"] = self.latent_embeddings.norm(dim=-1).mean()

            is_grad_enabled = torch.is_grad_enabled()

            # multi-view regularization: an atom should not change when the view changes
            if self.hparams.loss_multi_view_reg and is_grad_enabled:
                assert ray_dirs.requires_grad, ray_dirs
                assert plucker.requires_grad, plucker
                assert intersections.grad_fn is not None
                assert intersection_normals.grad_fn is not None
                *center_grads, radii_grads = diff.gradients(
                    sphere_centers[..., 0],
                    sphere_centers[..., 1],
                    sphere_centers[..., 2],
                    sphere_radii,
                    wrt = ray_dirs,
                )
                losses["loss_multi_view_reg"] = (
                    sum(
                        i.pow(2).sum(dim=-1) for i in center_grads
                    ).where(target_hits & is_intersecting, zeros).mean()
                    + radii_grads.pow(2).sum(dim=-1)
                        .where(target_hits & is_intersecting, zeros).mean()
                )

            # minimize the volume spanned by each atom
            if self.hparams.loss_atom_centroid_norm_std_reg and self.net.n_atoms > 1:
                assert len(all_sphere_centers.shape) == 5, all_sphere_centers.shape
                losses["loss_atom_centroid_norm_std_reg"] = ((
                    all_sphere_centers - all_sphere_centers.mean(dim=(1, 2), keepdim=True)
                ).pow(2).sum(dim=-1) - 0.05**2).clamp(0, None).mean()

            # PRIF uses l1, LSMAT uses l2
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_inscription_hits or self.hparams.loss_inscription_miss \
                    or self.hparams.loss_inscription_hits_l2 or self.hparams.loss_inscription_miss_l2:
                b = target_hits.shape[0]                     # number of objects
                n = target_hits.shape[1:].numel()            # rays per object
                perm = torch.randperm(n, device=self.device) # ray2ray permutation
                flatten = dict(start_dim=1, end_dim=len(target_hits.shape) - 1)
                (
                    inscr_sphere_center_projs, # (b, n, n_atoms, 3)
                    inscr_intersections_near,  # (b, n, n_atoms, 3)
                    inscr_intersections_far,   # (b, n, n_atoms, 3)
                    inscr_is_intersecting,     # (b, n, n_atoms) dtype=bool
                ) = geometry.ray_sphere_intersect(
                    ray_origins       .flatten(**flatten)[:, perm, None, :],
                    ray_dirs          .flatten(**flatten)[:, perm, None, :],
                    all_sphere_centers.flatten(**flatten),
                    all_sphere_radii  .flatten(**flatten),
                    return_parts       = True,
                    allow_nans         = False,
                    improve_miss_grads = self.hparams.improve_miss_grads,
                )
                assert inscr_sphere_center_projs.shape == (b, n, self.net.n_atoms, 3), \
                    (inscr_sphere_center_projs.shape, (b, n, self.net.n_atoms, 3))
                inscr_silhouettes = (
                    inscr_sphere_center_projs - all_sphere_centers.flatten(**flatten)
                ).norm(dim=-1) - all_sphere_radii.flatten(**flatten)
                loss_inscription_hits = (
                    (
                        (inscr_intersections_near - target_points.flatten(**flatten)[:, perm, None, :])
                        * ray_dirs.flatten(**flatten)[:, perm, None, :]
                    ).sum(dim=-1)
                ).where(target_hits.flatten(**flatten)[:, perm, None] & inscr_is_intersecting,
                    torch.zeros(inscr_intersections_near.shape[:-1], **self.device_and_dtype),
                ).clamp(None, 0)
                loss_inscription_miss = (
                    inscr_silhouettes - target_distances.flatten(**flatten)[:, perm, None]
                ).where(target_miss.flatten(**flatten)[:, perm, None],
                    torch.zeros_like(inscr_silhouettes),
                ).clamp(None, 0)
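                # The random permutation pairs every atom candidate with the
                # supervision of a *different* ray of the same object: a candidate
                # sphere must not protrude in front of a surface point observed along
                # another ray (hits), and must not come closer to another ray than
                # that ray's measured free-space distance (misses). Only violations
                # survive the clamp(None, 0); compliant pairs contribute exactly 0.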
                if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_inscription_hits:
                    losses["loss_inscription_hits"]    = loss_inscription_hits.neg().mean()
                if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_inscription_miss:
                    losses["loss_inscription_miss"]    = loss_inscription_miss.neg().mean()
                if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_inscription_hits_l2:
                    losses["loss_inscription_hits_l2"] = loss_inscription_hits.pow(2).mean()
                if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_inscription_miss_l2:
                    losses["loss_inscription_miss_l2"] = loss_inscription_miss.pow(2).mean()

            # metrics
            metrics["iou"] = (
                ((~target_miss) & is_intersecting.detach()).sum()
                / ((~target_miss) | is_intersecting.detach()).sum()
            )
            metrics["radii"] = sphere_radii.detach().mean() # the grow regularizers apply constant pressure, so track the mean radius directly

        elif self.hparams.output_mode == "orthogonal_plane":
            assert isinstance(self.net, OrthogonalPlaneNet)
            ray_origins, ray_dirs, input_embedding, (
                intersections,   # (..., 3) dtype=float
                is_intersecting, # (...) dtype=float
            ) = self(batch, z, return_input=True, normalize_origins=True)

            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_intersection:
                losses["loss_intersection"] = (
                    (intersections - target_points).norm(dim=-1)
                ).where(target_hits, zeros).mean()
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_intersection_l2:
                losses["loss_intersection_l2"] = (
                    (intersections - target_points).pow(2).sum(dim=-1)
                ).where(target_hits, zeros).mean()
            if (__debug__ or LOG_ALL_METRICS) or self.hparams.loss_hit_cross_entropy:
                losses["loss_hit_cross_entropy"] = (
                    F.binary_cross_entropy_with_logits(is_intersecting, (~target_miss).to(self.dtype))
                ).mean()
            if self.hparams.loss_normal_cossim and torch.is_grad_enabled():
                jac = diff.jacobian(intersections, ray_origins)
                intersection_normals = self.compute_normals_from_intersection_origin_jacobian(jac, ray_dirs)
                losses["loss_normal_cossim"] = (
                    1 - torch.cosine_similarity(target_normals, intersection_normals, dim=-1)
                ).where(target_hits, zeros).mean()
            if self.hparams.loss_normal_euclid and torch.is_grad_enabled():
                jac = diff.jacobian(intersections, ray_origins)
                intersection_normals = self.compute_normals_from_intersection_origin_jacobian(jac, ray_dirs)
                losses["loss_normal_euclid"] = (
                    (target_normals - intersection_normals).norm(dim=-1)
                ).where(target_hits, zeros).mean()
            if self.hparams.loss_multi_view_reg and torch.is_grad_enabled():
                assert ray_dirs.requires_grad, ray_dirs
                assert intersections.grad_fn is not None
                grads = diff.gradients(
                    intersections[..., 0],
                    intersections[..., 1],
                    intersections[..., 2],
                    wrt = ray_dirs,
                )
                losses["loss_multi_view_reg"] = sum(
                    i.pow(2).sum(dim=-1) for i in grads
                ).where(target_hits, zeros).mean()

            metrics["iou"] = (
                ((~target_miss) & (is_intersecting > 0.5).detach()).sum()
                / ((~target_miss) | (is_intersecting > 0.5).detach()).sum()
            )

        else:
            raise NotImplementedError(self.hparams.output_mode)

        # output losses and metrics

        # apply scaling:
        losses_unscaled = losses.copy() # shallow copy
        for k in list(losses.keys()):
            assert losses[k].numel() == 1, f"losses[{k!r}] shape: {losses[k].shape}"
            val_schedule: HParamSchedule = self.hparams[k]
            val = val_schedule.get(self)
            if val == 0:
                if (__debug__ or LOG_ALL_METRICS) and val_schedule.is_const:
                    del losses[k] # it was only added for unscaled logging, do not backprop it
                else:
                    losses[k] = 0
            elif val != 1:
                losses[k] = losses[k] * val

        if not losses:
            raise MisconfigurationException("no loss was computed")

        losses["loss"] = sum(losses.values()) * self.hparams.opt_warmup.get(self)
        losses.update({f"unscaled_{k}": v.detach() for k, v in losses_unscaled.items()})
        losses.update({f"metric_{k}": v.detach() for k, v in metrics.items()})
        return losses
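    # training_step/validation_step return a single dict: "loss" (the warmup-scaled
    # sum that is backpropagated), the individual scaled "loss_*" terms, their
    # "unscaled_*" counterparts, and "metric_*" entries (e.g. metric_iou, metric_radii).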
losses.update({f"unscaled_{k}": v.detach() for k, v in losses_unscaled.items()}) losses.update({f"metric_{k}": v.detach() for k, v in metrics.items()}) return losses # used by pl.callbacks.EarlyStopping, via cli.py @property def metric_early_stop(self): return ( "unscaled_loss_intersection_proj" if self.hparams.output_mode == "medial_sphere" else "unscaled_loss_intersection" ) def validation_step(self, batch: TrainingBatch, batch_idx: int) -> dict[str, Tensor]: losses = self.training_step(batch, batch_idx, is_validation=True) return losses def configure_optimizers(self): adam = torch.optim.Adam(self.parameters(), lr=1 if not self.hparams.opt_learning_rate.is_const else self.hparams.opt_learning_rate.get_train_value(0), weight_decay=self.hparams.opt_weight_decay) schedules = [] if not self.hparams.opt_learning_rate.is_const: schedules = [ torch.optim.lr_scheduler.LambdaLR(adam, lambda epoch: self.hparams.opt_learning_rate.get_train_value(epoch), ), ] return [adam], schedules @property def example_input_array(self) -> tuple[dict[str, Tensor], Tensor]: return ( { # see self.batch2rays "origins" : torch.zeros(1, 3), # most commonly used "points" : torch.zeros(1, 3), # used if self.training and self.hparams.loss_multi_view_reg "dirs" : torch.ones(1, 3) * torch.rsqrt(torch.tensor(3)), }, torch.ones(1, self.hparams.latent_features), ) @staticmethod def compute_normals_from_intersection_origin_jacobian(origin_jac: Tensor, ray_dirs: Tensor) -> Tensor: normals = sum(( torch.cross(origin_jac[..., 0], origin_jac[..., 1], dim=-1) * -ray_dirs[..., [2]], torch.cross(origin_jac[..., 1], origin_jac[..., 2], dim=-1) * -ray_dirs[..., [0]], torch.cross(origin_jac[..., 2], origin_jac[..., 0], dim=-1) * -ray_dirs[..., [1]], )) return normals / normals.norm(dim=-1, keepdim=True) class IntersectionFieldAutoDecoderModel(IntersectionFieldModel, AutoDecoderModuleMixin): def encode(self, batch: LabeledBatch) -> Tensor: assert not isinstance(self.trainer.strategy, pl.strategies.DataParallelStrategy) return self[batch["z_uid"]] # [N, Z_n]