Files
marf/experiments/marf.py
2025-01-09 15:43:11 +01:00

625 lines
33 KiB
Python
Executable File

#!/usr/bin/env python3
from abc import ABC, abstractmethod
from argparse import Namespace
from collections import defaultdict
from datetime import datetime
from ifield import logging
from ifield.cli import CliInterface
from ifield.data.common.scan import SingleViewUVScan
from ifield.data.coseg import read as coseg_read
from ifield.data.stanford import read as stanford_read
from ifield.datasets import stanford, coseg, common
from ifield.models import intersection_fields
from ifield.utils.operators import diff
from ifield.viewer.ray_field import ModelViewer
from munch import Munch
from pathlib import Path
from pytorch3d.loss.chamfer import chamfer_distance
from pytorch_lightning.utilities import rank_zero_only
from torch.nn import functional as F
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
from trimesh import Trimesh
from typing import Union
import builtins
import itertools
import json
import numpy as np
import pytorch_lightning as pl
import rich
import rich.pretty
import statistics
import torch
# Module-import side effects: seed through Lightning so runs are repeatable,
# and select PyTorch's 'medium' float32 matmul precision setting.
pl.seed_everything(31337)
torch.set_float32_matmul_precision('medium')
# Short alias for the model class; used by the CLI and in type hints below.
IField = intersection_fields.IntersectionFieldAutoDecoderModel # brevity
class RayFieldAdDataModuleBase(pl.LightningDataModule, ABC):
    """Common Lightning data module for auto-decoder training on single-view UV scans.

    Subclasses provide the dataset (``mk_ad_dataset``), the observation ids and
    the per-uid ground-truth getters. This base class expands every scan into
    ``step * step`` strided sub-samples, splits them into a train and a
    validation subset, and builds the two dataloaders.

    The following entries are read from ``self.hparams`` and must therefore be
    saved by the subclass: ``data_dir``, ``step``, ``val_fraction``,
    ``prng_seed``, ``batch_size``, ``drop_last``, ``num_workers``,
    ``persistent_workers``, ``pin_memory``, ``prefetch_factor``, ``shuffle``.
    """

    @property
    @abstractmethod
    def observation_ids(self) -> list[str]:
        """Identifiers of the observations served by this data module."""

    @abstractmethod
    def mk_ad_dataset(self) -> common.AutodecoderDataset:
        """Create the auto-decoder dataset that ``setup`` maps and splits."""

    @staticmethod
    @abstractmethod
    def get_trimesh_from_uid(uid) -> Trimesh:
        """Return the mesh that belongs to ``uid``."""

    @staticmethod
    @abstractmethod
    def get_sphere_scan_from_uid(uid) -> SingleViewUVScan:
        """Return the sphere scan that belongs to ``uid``."""

    def setup(self, stage=None):
        """Create ``self.dataset_train``, ``self.dataset_val`` and ``self.generator``.

        Each ``(uid, scan)`` item of the dataset is mapped to ``step * step``
        dict samples, one per ``(sx, sy)`` offset, holding the
        ``[sx::step, sy::step]`` slice of the per-ray arrays of the scan
        (``cam_pos`` is passed through unsliced).

        Args:
            stage: Only ``"fit"`` and ``None`` are accepted.

        Raises:
            AssertionError: On an unsupported ``stage``, on an unexpected
                dataset length after mapping, or when a subset length is not
                a multiple of ``batch_size``.
        """
        assert stage in ["fit", None] # fit is for train/val, None is for all. "test" not supported ATM

        hp = self.hparams # brevity
        if hp.data_dir is not None:
            # NOTE(review): this overrides the *coseg* data path for every
            # subclass, including non-coseg ones -- confirm this is intended.
            coseg.config.DATA_PATH = hp.data_dir
        step = hp.step

        dataset = self.mk_ad_dataset()
        n_items_pre_step_mapping = len(dataset)

        if step > 1:
            dataset = common.TransformExtendedDataset(dataset)

        def make_slicer(sx, sy, step) -> callable:
            # A factory is required so every lambda binds its own sx/sy
            # instead of late-binding the loop variables.
            if step > 1:
                return lambda t: t[sx::step, sy::step]
            return lambda t: t

        # Same iteration order as two nested ``range(step)`` loops.
        for sx, sy in itertools.product(range(step), repeat=2):
            @dataset.map(slicer=make_slicer(sx, sy, step))
            def unpack(sample: tuple[str, SingleViewUVScan], slicer: callable):
                scan: SingleViewUVScan = sample[1]
                assert not scan.hits.shape[0] % step, f"{scan.hits.shape[0]=} not divisible by {step=}"
                assert not scan.hits.shape[1] % step, f"{scan.hits.shape[1]=} not divisible by {step=}"

                return {
                    "z_uid"     : sample[0],
                    "origins"   : scan.cam_pos,
                    "dirs"      : slicer(scan.ray_dirs),
                    "points"    : slicer(scan.points),
                    "hits"      : slicer(scan.hits),
                    "miss"      : slicer(scan.miss),
                    "normals"   : slicer(scan.normals),
                    "distances" : slicer(scan.distances),
                }

        # Train/val split sizes, counted in mapped samples.
        n_items = len(dataset)
        n_val   = int(n_items * hp.val_fraction)
        n_train = n_items - n_val
        self.generator = torch.Generator().manual_seed(hp.prng_seed)

        # Permute whole scans and keep the step*step samples of one scan
        # adjacent, so that they end up in the same part of the split.
        # NOTE(review): n_train is not rounded to a multiple of step*step, so
        # the cut can still fall inside one scan's group -- confirm/adjust.
        assert n_items == n_items_pre_step_mapping * step * step, (n_items, n_items_pre_step_mapping, step)
        indices = [
            i*step*step + sx*step + sy
            for i in torch.randperm(n_items_pre_step_mapping, generator=self.generator).tolist()
            for sx in range(step)
            for sy in range(step)
        ]

        def shuffled(idxs: list[int]) -> list[int]:
            # Sorting on one fresh random key per element shuffles the list
            # while drawing exclusively from the seeded generator.
            return sorted(idxs, key=lambda _: torch.rand(1, generator=self.generator).tolist()[0])

        # Train keys are drawn before val keys; keep this order for reproducibility.
        self.dataset_train = Subset(dataset, shuffled(indices[:n_train]))
        self.dataset_val   = Subset(dataset, shuffled(indices[n_train:n_train+n_val]))

        assert len(self.dataset_train) % hp.batch_size == 0
        assert len(self.dataset_val)   % hp.batch_size == 0

    def _shared_loader_kwargs(self) -> dict:
        """Keyword arguments common to the train and the validation ``DataLoader``."""
        hp = self.hparams
        return dict(
            batch_size         = hp.batch_size,
            drop_last          = hp.drop_last,
            num_workers        = hp.num_workers,
            persistent_workers = hp.persistent_workers,
            pin_memory         = hp.pin_memory,
            prefetch_factor    = hp.prefetch_factor,
            generator          = self.generator,
        )

    def train_dataloader(self):
        """Return the loader over ``dataset_train``; shuffling follows ``hparams.shuffle``."""
        return DataLoader(self.dataset_train,
            shuffle = self.hparams.shuffle,
            **self._shared_loader_kwargs(),
        )

    def val_dataloader(self):
        """Return the loader over ``dataset_val``; no ``shuffle`` argument is passed."""
        return DataLoader(self.dataset_val, **self._shared_loader_kwargs())
class StanfordUVDataModule(RayFieldAdDataModuleBase):
    """Data module backed by the Stanford single-view UV scan dataset."""

    # Not read anywhere in this file; presumably the up axis of this dataset -- confirm.
    skyward = "+Z"

    # NOTE(review): `obj_names` has a mutable list default and `pin_memory` is
    # annotated `int` although it holds a bool. Both are left untouched here
    # because the CLI may derive options from this signature.
    def __init__(
        self,
        data_dir           : Union[str, Path, None] = None,
        obj_names          : list[str]              = ["bunny"], # empty means all
        prng_seed          : int                    = 1337,
        step               : int                    = 2,
        batch_size         : int                    = 5,
        drop_last          : bool                   = False,
        num_workers        : int                    = 8,
        persistent_workers : bool                   = True,
        pin_memory         : int                    = True,
        prefetch_factor    : int                    = 2,
        shuffle            : bool                   = True,
        val_fraction       : float                  = 0.30,
    ):
        """Store all arguments as hyperparameters.

        An empty ``obj_names`` is replaced by every name reported by
        ``stanford_read.list_object_names()`` before it is stored.
        """
        super().__init__()
        # Rebind the argument itself: save_hyperparameters() is called
        # without arguments, so it has to see the substituted value here.
        obj_names = obj_names or stanford_read.list_object_names()
        self.save_hyperparameters()

    @property
    def observation_ids(self) -> list[str]:
        """The object names stored in the hyperparameters."""
        return self.hparams.obj_names

    def mk_ad_dataset(self) -> common.AutodecoderDataset:
        """Build the Stanford dataset for the configured objects and data path."""
        hp = self.hparams
        return stanford.AutodecoderSingleViewUVScanDataset(
            obj_names = hp.obj_names,
            data_path = hp.data_dir,
        )

    @staticmethod
    def get_trimesh_from_uid(obj_name) -> Trimesh:
        """Read the mesh of ``obj_name`` and pass it through ``mesh_to_sdf.scale_to_unit_sphere``."""
        import mesh_to_sdf # imported lazily, only needed for ground-truth meshes
        return mesh_to_sdf.scale_to_unit_sphere(stanford_read.read_mesh(obj_name))

    @staticmethod
    def get_sphere_scan_from_uid(obj_name) -> SingleViewUVScan:
        """Read the sphere scan stored for ``obj_name``."""
        return stanford_read.read_mesh_mesh_sphere_scan(obj_name)
class CosegUVDataModule(RayFieldAdDataModuleBase):
    """Data module backed by the COSEG single-view UV scan dataset."""

    # Not read anywhere in this file; presumably the up axis of this dataset -- confirm.
    skyward = "+Y"

    # NOTE(review): `object_sets` is annotated as a tuple but defaults to a
    # list, and `pin_memory` is annotated `int` although it holds a bool.
    # Both are left untouched because the CLI may derive options from this
    # signature.
    def __init__(
        self,
        data_dir           : Union[str, Path, None] = None,
        object_sets        : tuple[str]             = ["tele-aliens"], # empty means all
        prng_seed          : int                    = 1337,
        step               : int                    = 2,
        batch_size         : int                    = 5,
        drop_last          : bool                   = False,
        num_workers        : int                    = 8,
        persistent_workers : bool                   = True,
        pin_memory         : int                    = True,
        prefetch_factor    : int                    = 2,
        shuffle            : bool                   = True,
        val_fraction       : float                  = 0.30,
    ):
        """Store all arguments as hyperparameters.

        An empty ``object_sets`` is replaced by every set reported by
        ``coseg_read.list_object_sets()``; the value is stored as a tuple.
        """
        super().__init__()
        # Rebind the argument itself: save_hyperparameters() is called
        # without arguments, so it has to see the final tuple here.
        object_sets = tuple(object_sets or coseg_read.list_object_sets())
        self.save_hyperparameters()

    @property
    def observation_ids(self) -> list[str]:
        """Model id strings of every model in the configured object sets."""
        return coseg_read.list_model_id_strings(self.hparams.object_sets)

    def mk_ad_dataset(self) -> common.AutodecoderDataset:
        """Build the COSEG dataset for the configured object sets and data path."""
        hp = self.hparams
        return coseg.AutodecoderSingleViewUVScanDataset(
            object_sets = hp.object_sets,
            data_path   = hp.data_dir,
        )

    @staticmethod
    def get_trimesh_from_uid(string_uid):
        """Not available for COSEG; always raises ``NotImplementedError``."""
        raise NotImplementedError

    @staticmethod
    def get_sphere_scan_from_uid(string_uid) -> SingleViewUVScan:
        """Read the sphere scan for a model id string."""
        # The reader takes the parts of the decoded uid as separate arguments.
        return coseg_read.read_mesh_mesh_sphere_scan(
            *coseg_read.model_id_string_to_uid(string_uid)
        )
def mk_cli(args=None) -> CliInterface:
    """Assemble the command line interface of this experiment.

    Creates a ``CliInterface`` for the ``IField`` model and the two data
    modules, sets trainer defaults, and registers one pre-training callback
    plus the actions ``viewer``, ``render_video_interpolation``,
    ``render_video_spin`` and ``compute_scores`` on it.

    Args:
        args: Not used by this function.

    Returns:
        The configured ``CliInterface``; the caller runs it.
    """
    cli = CliInterface(
        module_cls             = IField,
        datamodule_cls         = [StanfordUVDataModule, CosegUVDataModule],
        workdir                = Path(__file__).parent.resolve(),
        experiment_name_prefix = "ifield",
    )
    cli.trainer_defaults.update(dict(
        precision  = 16,
        min_epochs = 5,
    ))

    # Values shared by the rendering actions below.
    screenshot_dir = Path(__file__).parent.parent / "images/pygame-viewer"
    background_by_name = {"map": True, "white": 255, "black": 0}

    def move_to_cuda_if_available(model: IField) -> None:
        # Use the GPU when one is present, otherwise leave the model where it is.
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            model.to("cuda")

    def vizmode_arg(flag: str, modes, default_mode) -> tuple:
        # Integer option selecting one entry of `modes`; the help lists "index: mode" pairs.
        listing = ', '.join(f'{i}: {m!r}' for i, m in enumerate(modes))
        return (flag, dict(type=int, default=modes.index(default_mode), help=f"Rendering mode. {{{listing}}}"))

    def display_mode_args() -> list:
        # Fresh specs on every call, so no dict is shared between registered actions.
        return [
            vizmode_arg("--shading",  ModelViewer.vizmodes_shading,   "lambertian"),
            vizmode_arg("--centroid", ModelViewer.vizmodes_centroids, "best-centroids-colored"),
            vizmode_arg("--spheres",  ModelViewer.vizmodes_spheres,   None),
            ("--analytical-normals", dict(action="store_true")),
        ]

    # Hand the observation ids of the data module to the model before training
    # starts, and print a short per-rank summary.
    @cli.register_pre_training_callback
    def populate_autodecoder_z_uids(args: Namespace, config: Munch, module: IField, trainer: pl.Trainer, datamodule: RayFieldAdDataModuleBase, logger: logging.Logger):
        module.set_observation_ids(datamodule.observation_ids)
        rank = getattr(rank_zero_only, "rank", 0)
        rich.print(f"[rank {rank}] {len(datamodule.observation_ids) = }")
        rich.print(f"[rank {rank}] {len(datamodule.observation_ids) > 1 = }")
        rich.print(f"[rank {rank}] {module.is_conditioned = }")

    # Action: open the interactive viewer, or with --write render one frame to a .png.
    @cli.register_action(help="Interactive window with direct renderings from the model", args=[
        *display_mode_args(),
        ("--ground-truth", dict(action="store_true")),
        ("--solo-atom",    dict(type=int, default=None, help="Rendering mode")),
        ("--res",          dict(type=int, nargs=2, default=(210, 160), help="Rendering resolution")),
        ("--bg",           dict(choices=["map", "white", "black"], default="map")),
        ("--skyward",      dict(type=str, default="+Z", help='one of: "+X", "-X", "+Y", "-Y", ["+Z"], "-Z"')),
        ("--scale",        dict(type=int, default=3, help="Rendering scale")),
        ("--fps",          dict(type=int, default=None, help="FPS upper limit")),
        ("--cam-state",    dict(type=str, default=None, help="json cam state, expored with CTRL+H")),
        ("--write",        dict(type=Path, default=None, help="Where to write a screenshot.")),
    ])
    @torch.no_grad()
    def viewer(args: Namespace, config: Munch, model: IField):
        datamodule_cls: RayFieldAdDataModuleBase = cli.get_datamodule_cls_from_config(args, config)
        move_to_cuda_if_available(model)

        mv = ModelViewer(model, start_uid=next(iter(model.keys())),
            name           = config.experiment_name,
            screenshot_dir = screenshot_dir,
            res            = args.res,
            skyward        = args.skyward,
            scale          = args.scale,
            mesh_gt_getter = datamodule_cls.get_trimesh_from_uid,
        )
        mv.display_mode_shading  = args.shading
        mv.display_mode_centroid = args.centroid
        mv.display_mode_spheres  = args.spheres
        # --analytical-normals is applied last and therefore wins over --ground-truth.
        if args.ground_truth:
            mv.display_mode_normals = mv.vizmodes_normals.index("ground_truth")
        if args.analytical_normals:
            mv.display_mode_normals = mv.vizmodes_normals.index("analytical")
        mv.atom_index_solo = args.solo_atom
        mv.fps_cap = args.fps
        mv.display_sphere_map_bg = background_by_name[args.bg]
        if args.cam_state is not None:
            mv.cam_state = json.loads(args.cam_state)

        if args.write is None:
            mv.run()
            return
        assert args.write.suffix == ".png", args.write.name
        mv.render_headless(args.write,
            n_frames       = 1,
            fps            = 1,
            state_callback = None,
        )

    # Action: render a video that moves from one uid to the next. Without
    # explicit uids all model keys are used and the first one is appended
    # again at the end.
    @cli.register_action(help="Prerender direct renderings from the model", args=[
        ("output_path", dict(type=Path, help="Where to store the output. We recommend a .mp4 suffix.")),
        ("uids",        dict(type=str, nargs="*")),
        ("--frames",    dict(type=int, default=60, help="Number of per interpolation. Default is 60")),
        ("--fps",       dict(type=int, default=60, help="Default is 60")),
        *display_mode_args(),
        ("--solo-atom", dict(type=int, default=None, help="Rendering mode")),
        ("--res",       dict(type=int, nargs=2, default=(240, 240), help="Rendering resolution. Default is 240 240")),
        ("--bg",        dict(choices=["map", "white", "black"], default="map")),
        ("--skyward",   dict(type=str, default="+Z", help='one of: "+X", "-X", "+Y", "-Y", ["+Z"], "-Z"')),
        ("--bitrate",   dict(type=str, default="1500k", help="Encoding bitrate. Default is 1500k")),
        ("--cam-state", dict(type=str, default=None, help="json cam state, expored with CTRL+H")),
    ])
    @torch.no_grad()
    def render_video_interpolation(args: Namespace, config: Munch, model: IField, **kw):
        move_to_cuda_if_available(model)
        uids = args.uids or list(model.keys())
        assert len(uids) > 1
        if not args.uids:
            uids.append(uids[0])

        mv = ModelViewer(model, uids[0],
            name           = config.experiment_name,
            screenshot_dir = screenshot_dir,
            res            = args.res,
            skyward        = args.skyward,
        )
        if args.cam_state is not None:
            mv.cam_state = json.loads(args.cam_state)
        mv.display_mode_shading  = args.shading
        mv.display_mode_centroid = args.centroid
        mv.display_mode_spheres  = args.spheres
        if args.analytical_normals:
            mv.display_mode_normals = mv.vizmodes_normals.index("analytical")
        mv.atom_index_solo = args.solo_atom
        mv.display_sphere_map_bg = background_by_name[args.bg]

        def state_callback(self: ModelViewer, frame: int):
            if frame % args.frames:
                # in-between frame
                self.lambertian_color = (0.8, 0.8, 1.0)
            else:
                # Every `--frames`-th frame: switch to the next uid.
                # NOTE(review): the nesting of this branch was reconstructed
                # from a flattened source -- confirm against the original.
                self.lambertian_color = (1.0, 1.0, 1.0)
                self.fps = args.frames
                idx = frame // args.frames + 1
                if idx != len(uids):
                    self.current_uid = uids[idx]

        print(f"Writing video to {str(args.output_path)!r}...")
        mv.render_headless(args.output_path,
            n_frames       = args.frames * (len(uids)-1) + 1,
            fps            = args.fps,
            state_callback = state_callback,
            bitrate        = args.bitrate,
        )

    # Action: render a video of the first model key while `cam_rot_x` advances every frame.
    @cli.register_action(help="Prerender direct renderings from the model", args=[
        ("output_path", dict(type=Path, help="Where to store the output. We recommend a .mp4 suffix.")),
        ("--frames",    dict(type=int, default=180, help="Number of frames. Default is 180")),
        ("--fps",       dict(type=int, default=60, help="Default is 60")),
        *display_mode_args(),
        ("--solo-atom", dict(type=int, default=None, help="Rendering mode")),
        ("--res",       dict(type=int, nargs=2, default=(320, 240), help="Rendering resolution. Default is 320 240")),
        ("--bg",        dict(choices=["map", "white", "black"], default="map")),
        ("--skyward",   dict(type=str, default="+Z", help='one of: "+X", "-X", "+Y", "-Y", ["+Z"], "-Z"')),
        ("--bitrate",   dict(type=str, default="1500k", help="Encoding bitrate. Default is 1500k")),
        ("--cam-state", dict(type=str, default=None, help="json cam state, expored with CTRL+H")),
    ])
    @torch.no_grad()
    def render_video_spin(args: Namespace, config: Munch, model: IField, **kw):
        move_to_cuda_if_available(model)

        mv = ModelViewer(model, start_uid=next(iter(model.keys())),
            name           = config.experiment_name,
            screenshot_dir = screenshot_dir,
            res            = args.res,
            skyward        = args.skyward,
        )
        if args.cam_state is not None:
            mv.cam_state = json.loads(args.cam_state)
        mv.display_mode_shading  = args.shading
        mv.display_mode_centroid = args.centroid
        mv.display_mode_spheres  = args.spheres
        if args.analytical_normals:
            mv.display_mode_normals = mv.vizmodes_normals.index("analytical")
        mv.atom_index_solo = args.solo_atom
        mv.display_sphere_map_bg = background_by_name[args.bg]

        cam_rot_x_init = mv.cam_rot_x

        def state_callback(self: ModelViewer, frame: int):
            # NOTE(review): 3.14 * 2 presumably stands for one revolution in
            # radians; being slightly below 2*pi, the loop does not close exactly.
            self.cam_rot_x = cam_rot_x_init + 3.14 * (frame / args.frames) * 2

        print(f"Writing video to {str(args.output_path)!r}...")
        mv.render_headless(args.output_path,
            n_frames       = args.frames,
            fps            = args.fps,
            state_callback = state_callback,
            bitrate        = args.bitrate,
        )

    # Action: evaluate the model against the ground-truth sphere scan of every
    # uid (or only the first with --single-shape), print a report and write
    # the per-uid metric tables as JSON to `fname` ("-" prints them to stdout).
    # Gradients stay enabled because normals are derived from the jacobian of
    # the intersections with respect to the ray origins.
    # Raises AssertionError when a shape yields fewer than --n-cd predicted
    # intersections.
    @cli.register_action(help="foo", args=[
        ("fname",              dict(type=Path, help="where to write json")),
        ("-t", "--transpose",  dict(action="store_true", help="transpose the output")),
        ("--single-shape",     dict(action="store_true", help="break after first shape")),
        ("--batch-size",       dict(type=int, default=40_000, help="tradeoff between vram usage and efficiency")),
        ("--n-cd",             dict(type=int, default=30_000, help="Number of points to use when computing chamfer distance")),
        ("--filter-outliers",  dict(action="store_true", help="like in PRIF")),
    ])
    @torch.enable_grad()
    def compute_scores(args: Namespace, config: Munch, model: IField, **kw):
        datamodule_cls: RayFieldAdDataModuleBase = cli.get_datamodule_cls_from_config(args, config)
        model.eval()
        move_to_cuda_if_available(model)

        def T(array: np.ndarray, **kw) -> torch.Tensor:
            # Tensors pass through untouched; arrays are moved to the model
            # device, and floating point data is cast to the model dtype.
            if isinstance(array, torch.Tensor):
                return array
            # An ndarray is never an instance of np.floating, so the dtype of
            # the array has to be inspected instead.
            is_floating = isinstance(array, np.floating) or (
                isinstance(array, np.ndarray) and np.issubdtype(array.dtype, np.floating)
            )
            return torch.tensor(array, device=model.device, dtype=model.dtype if is_floating else None, **kw)

        MEDIAL = model.hparams.output_mode == "medial_sphere"
        if not MEDIAL:
            assert model.hparams.output_mode == "orthogonal_plane"

        uids = sorted(model.keys())
        if args.single_shape:
            uids = [uids[0]]
        rich.print(f"{datamodule_cls.__name__ = }")
        rich.print(f"{len(uids) = }")

        # accumulators for IoU and F-Score, CD and COS, all keyed by uid

        # sum reduction:
        n            = defaultdict(int)
        n_gt_hits    = defaultdict(int)
        n_gt_miss    = defaultdict(int)
        n_gt_missing = defaultdict(int)
        n_outliers   = defaultdict(int)
        p_mse        = defaultdict(int)
        s_mse        = defaultdict(int)
        cossim_med   = defaultdict(int) # medial normals
        cossim_jac   = defaultdict(int) # jacobian normals
        TP, FN, FP, TN = [defaultdict(int) for _ in range(4)] # IoU and f-score

        # mean reduction:
        cd_dist    = {} # chamfer distance
        cd_cos_med = {} # chamfer medial normals
        cd_cos_jac = {} # chamfer jacobian normals

        all_metrics = dict(
            n=n, n_gt_hits=n_gt_hits, n_gt_miss=n_gt_miss, n_gt_missing=n_gt_missing, p_mse=p_mse,
            cossim_jac=cossim_jac,
            TP=TP, FN=FN, FP=FP, TN=TN, cd_dist=cd_dist,
            cd_cos_jac=cd_cos_jac,
        )
        if MEDIAL:
            all_metrics["s_mse"]      = s_mse
            all_metrics["cossim_med"] = cossim_med
            all_metrics["cd_cos_med"] = cd_cos_med
        if args.filter_outliers:
            all_metrics["n_outliers"] = n_outliers

        t = datetime.now()
        for uid in tqdm(uids, desc="Dataset", position=0, leave=True, disable=len(uids)<=1):
            sphere_scan_gt = datamodule_cls.get_sphere_scan_from_uid(uid)

            z = model[uid].detach()

            all_intersections    = []
            all_medial_normals   = []
            all_jacobian_normals = []

            step = args.batch_size
            for i in tqdm(range(0, sphere_scan_gt.hits.shape[0], step), desc=f"Item {uid!r}", position=1, leave=False):
                # prepare batch and gt
                origins      = T(sphere_scan_gt.cam_pos  [i:i+step, :], requires_grad = True)
                dirs         = T(sphere_scan_gt.ray_dirs [i:i+step, :])
                gt_hits      = T(sphere_scan_gt.hits     [i:i+step])
                gt_miss      = T(sphere_scan_gt.miss     [i:i+step])
                gt_missing   = T(sphere_scan_gt.missing  [i:i+step])
                gt_points    = T(sphere_scan_gt.points   [i:i+step, :])
                gt_normals   = T(sphere_scan_gt.normals  [i:i+step, :])
                gt_distances = T(sphere_scan_gt.distances[i:i+step])

                # forward
                rays = {
                    "origins" : origins,
                    "dirs"    : dirs,
                }
                if MEDIAL:
                    (
                        depths,
                        silhouettes,
                        intersections,
                        medial_normals,
                        is_intersecting,
                        sphere_centers,
                        sphere_radii,
                    ) = model(rays, z, intersections_only=False, allow_nans=False)
                else:
                    silhouettes = medial_normals = None
                    intersections, is_intersecting = model(rays, z, normalize_origins = True)
                # NOTE(review): applied after both branches; the nesting was
                # reconstructed from a flattened source. Thresholding a mask
                # that is already boolean leaves it unchanged.
                is_intersecting = is_intersecting > 0.5
                jac = diff.jacobian(intersections, origins, detach=True)

                # outlier removal (PRIF)
                if args.filter_outliers:
                    outliers = jac.norm(dim=-2).norm(dim=-1) > 5
                    n_outliers[uid] += outliers[is_intersecting].sum().item()
                    # We count filtered points as misses
                    is_intersecting &= ~outliers

                model.zero_grad()
                jacobian_normals = model.compute_normals_from_intersection_origin_jacobian(jac, dirs)

                # keep the predicted hits for the chamfer distance of this shape
                predicted = is_intersecting.detach()
                all_intersections.append(intersections.detach()[predicted, :])
                if MEDIAL:
                    all_medial_normals.append(medial_normals.detach()[predicted, :])
                all_jacobian_normals.append(jacobian_normals.detach()[predicted, :])

                # accumulate metrics
                with torch.no_grad():
                    n           [uid] += dirs.shape[0]
                    n_gt_hits   [uid] += gt_hits.sum().item()
                    n_gt_miss   [uid] += gt_miss.sum().item()
                    n_gt_missing[uid] += gt_missing.sum().item()
                    p_mse       [uid] += (gt_points[gt_hits, :] - intersections[gt_hits, :]).norm(2, dim=-1).pow(2).sum().item()
                    if MEDIAL:
                        s_mse     [uid] += (gt_distances[gt_miss] - silhouettes[gt_miss]).pow(2).sum().item()
                        cossim_med[uid] += (1-F.cosine_similarity(gt_normals[gt_hits, :], medial_normals[gt_hits, :], dim=-1).abs()).sum().item() # to match what pytorch3d does for CD
                    cossim_jac  [uid] += (1-F.cosine_similarity(gt_normals[gt_hits, :], jacobian_normals[gt_hits, :], dim=-1).abs()).sum().item() # to match what pytorch3d does for CD
                    # Rays flagged "missing" in the ground truth count as positives.
                    not_intersecting = ~is_intersecting
                    TP[uid] += ((gt_hits | gt_missing) & is_intersecting).sum().item() # True Positive
                    FN[uid] += ((gt_hits | gt_missing) & not_intersecting).sum().item() # False Negative
                    FP[uid] += (gt_miss & is_intersecting).sum().item() # False Positive
                    TN[uid] += (gt_miss & not_intersecting).sum().item() # True Negative

            all_intersections    = torch.cat(all_intersections, dim=0)
            all_medial_normals   = torch.cat(all_medial_normals, dim=0) if MEDIAL else None
            all_jacobian_normals = torch.cat(all_jacobian_normals, dim=0)

            hits = sphere_scan_gt.hits # brevity
            print()

            # random subsets of at most n_cd predicted and ground-truth points
            assert all_intersections.shape[0] >= args.n_cd
            idx_cd_pred = torch.randperm(all_intersections.shape[0])[:args.n_cd]
            idx_cd_gt   = torch.randperm(hits.sum())[:args.n_cd]

            print("cd... ", end="")
            tt = datetime.now()
            # the point sets are identical for both chamfer calls, build them once
            cd_pred_points = all_intersections[None, :, :][:, idx_cd_pred, :].detach()
            cd_gt_points   = T(sphere_scan_gt.points [None, hits, :][:, idx_cd_gt, :])
            cd_gt_normals  = T(sphere_scan_gt.normals[None, hits, :][:, idx_cd_gt, :])
            loss_cd, loss_cos_jac = chamfer_distance(
                x         = cd_pred_points,
                x_normals = all_jacobian_normals[None, :, :][:, idx_cd_pred, :].detach(),
                y         = cd_gt_points,
                y_normals = cd_gt_normals,
                batch_reduction = "sum", point_reduction = "sum",
            )
            if MEDIAL:
                _, loss_cos_med = chamfer_distance(
                    x         = cd_pred_points,
                    x_normals = all_medial_normals[None, :, :][:, idx_cd_pred, :].detach(),
                    y         = cd_gt_points,
                    y_normals = cd_gt_normals,
                    batch_reduction = "sum", point_reduction = "sum",
                )
            print(datetime.now() - tt)

            cd_dist   [uid] = loss_cd.item()
            cd_cos_med[uid] = loss_cos_med.item() if MEDIAL else None
            cd_cos_jac[uid] = loss_cos_jac.item()

            print()
            model.zero_grad(set_to_none=True)
        print("Total time:", datetime.now() - t)
        if len(uids) > 1:
            print("Time per item:", (datetime.now() - t) / len(uids))

        # Reductions over the values of one or more per-uid tables. The names
        # deliberately shadow the builtin/stdlib ones: they are echoed verbatim
        # by the f"{... = }" expressions of the report below.
        def sum(*tables):
            return builtins.sum(itertools.chain(*(table.values() for table in tables)))

        def mean(*tables):
            return statistics.mean(itertools.chain(*(table.values() for table in tables)))

        def stdev(*tables):
            return statistics.stdev(itertools.chain(*(table.values() for table in tables)))

        n_cd = args.n_cd
        P = sum(TP)/(sum(TP, FP))
        R = sum(TP)/(sum(TP, FN))
        print(f"{mean(n) = :11.1f} (rays per object)")
        print(f"{mean(n_gt_hits) = :11.1f} (gt rays hitting per object)")
        print(f"{mean(n_gt_miss) = :11.1f} (gt rays missing per object)")
        print(f"{mean(n_gt_missing) = :11.1f} (gt rays unknown per object)")
        if args.filter_outliers:
            print(f"{mean(n_outliers) = :11.1f} (predicted hits filtered as outliers per object)")
        print(f"{n_cd = :11.0f} (cd rays per object)")
        print(f"{mean(n_gt_hits) / mean(n) = :11.8f} (fraction rays hitting per object)")
        print(f"{mean(n_gt_miss) / mean(n) = :11.8f} (fraction rays missing per object)")
        print(f"{mean(n_gt_missing)/ mean(n) = :11.8f} (fraction rays unknown per object)")
        if args.filter_outliers:
            print(f"{mean(n_outliers) / mean(n) = :11.8f} (fraction rays filtered as outliers per object)")
        print(f"{sum(TP)/sum(n) = :11.8f} (total ray TP)")
        print(f"{sum(TN)/sum(n) = :11.8f} (total ray TN)")
        print(f"{sum(FP)/sum(n) = :11.8f} (total ray FP)")
        print(f"{sum(FN)/sum(n) = :11.8f} (total ray FN)")
        print(f"{sum(TP, FN, FP)/sum(n) = :11.8f} (total ray union)")
        print(f"{sum(TP)/sum(TP, FN, FP) = :11.8f} (total ray IoU)")
        print(f"{sum(TP)/(sum(TP, FP)) = :11.8f} -> P (total ray precision)")
        print(f"{sum(TP)/(sum(TP, FN)) = :11.8f} -> R (total ray recall)")
        print(f"{2*(P*R)/(P+R) = :11.8f} (total ray F-score)")
        print(f"{sum(p_mse)/sum(n_gt_hits) = :11.8f} (mean ray intersection mean squared error)")
        if MEDIAL:
            # silhouettes only exist in medial mode, like the other medial metrics
            print(f"{sum(s_mse)/sum(n_gt_miss) = :11.8f} (mean ray silhoutette mean squared error)")
            print(f"{sum(cossim_med)/sum(n_gt_hits) = :11.8f} (mean ray medial reduced cosine similarity)")
        print(f"{sum(cossim_jac)/sum(n_gt_hits) = :11.8f} (mean ray analytical reduced cosine similarity)")
        print(f"{mean(cd_dist) /n_cd * 1e3 = :11.8f} (mean chamfer distance)")
        if MEDIAL:
            print(f"{mean(cd_cos_med)/n_cd = :11.8f} (mean chamfer reduced medial cossim distance)")
        print(f"{mean(cd_cos_jac)/n_cd = :11.8f} (mean chamfer reduced analytical cossim distance)")
        # a standard deviation needs at least two shapes
        if len(cd_dist) > 1:
            print(f"{stdev(cd_dist) /n_cd * 1e3 = :11.8f} (stdev chamfer distance)")
        if len(cd_cos_med) > 1 and MEDIAL:
            print(f"{stdev(cd_cos_med)/n_cd = :11.8f} (stdev chamfer reduced medial cossim distance)")
        if len(cd_cos_jac) > 1:
            print(f"{stdev(cd_cos_jac)/n_cd = :11.8f} (stdev chamfer reduced analytical cossim distance)")

        if args.transpose:
            # metric -> uid -> value becomes uid -> metric -> value
            all_metrics, old_metrics = defaultdict(dict), all_metrics
            for m, table in old_metrics.items():
                for uid, vals in table.items():
                    all_metrics[uid][m] = vals
            all_metrics["_hparams"] = dict(n_cd=args.n_cd)
        else:
            all_metrics["n_cd"] = args.n_cd

        if str(args.fname) == "-":
            # one top-level key per line on stdout
            print("{", ',\n'.join(
                f"  {json.dumps(k)}: {json.dumps(v)}"
                for k, v in all_metrics.items()
            ), "}", sep="\n")
        else:
            args.fname.parent.mkdir(parents=True, exist_ok=True)
            with args.fname.open("w") as f:
                json.dump(all_metrics, f, indent=2)

    return cli
if __name__ == "__main__":
    # Script entry point: build the CLI and hand control over to it.
    mk_cli().run()