#!/usr/bin/env python3
"""Command-line entry point for intersection-field auto-decoder experiments.

Defines the Lightning data modules for the Stanford and COSEG single-view UV
scans, and builds a ``CliInterface`` with actions for interactive viewing,
video rendering and score computation.
"""
from abc import ABC, abstractmethod
from argparse import Namespace
from collections import defaultdict
from datetime import datetime
from ifield import logging
from ifield.cli import CliInterface
from ifield.data.common.scan import SingleViewUVScan
from ifield.data.coseg import read as coseg_read
from ifield.data.stanford import read as stanford_read
from ifield.datasets import stanford, coseg, common
from ifield.models import intersection_fields
from ifield.utils.operators import diff
from ifield.viewer.ray_field import ModelViewer
from munch import Munch
from pathlib import Path
from pytorch3d.loss.chamfer import chamfer_distance
from pytorch_lightning.utilities import rank_zero_only
from torch.nn import functional as F
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
from trimesh import Trimesh
from typing import Union
import builtins
import itertools
import json
import numpy as np
import pytorch_lightning as pl
import rich
import rich.pretty
import statistics
import torch

pl.seed_everything(31337)
torch.set_float32_matmul_precision('medium')


IField = intersection_fields.IntersectionFieldAutoDecoderModel # brevity


class RayFieldAdDataModuleBase(pl.LightningDataModule, ABC):
    """Shared train/val plumbing for auto-decoder UV-scan data modules.

    Subclasses provide the dataset (`mk_ad_dataset`), the list of observation
    ids, and static getters for ground-truth meshes / sphere scans. They are
    expected to call `save_hyperparameters()` so that `self.hparams` carries
    `data_dir`, `step`, `val_fraction`, `prng_seed`, `batch_size` and the
    DataLoader options read below.
    """

    @property
    @abstractmethod
    def observation_ids(self) -> list[str]:
        """The ids handed to the model via `set_observation_ids`."""
        ...

    @abstractmethod
    def mk_ad_dataset(self) -> common.AutodecoderDataset:
        """Construct the dataset that `setup` maps and splits."""
        ...

    @staticmethod
    @abstractmethod
    def get_trimesh_from_uid(uid) -> Trimesh:
        """Return the ground-truth mesh for `uid`."""
        ...

    @staticmethod
    @abstractmethod
    def get_sphere_scan_from_uid(uid) -> SingleViewUVScan:
        """Return the ground-truth sphere scan for `uid`."""
        ...

    def setup(self, stage=None):
        """Build `self.dataset_train` and `self.dataset_val`.

        Every scan is expanded into `step * step` strided sub-grids, after
        which the items are shuffled and split with a generator seeded from
        `hparams.prng_seed`.

        Raises:
            AssertionError: on an unsupported `stage`, on scan dimensions not
                divisible by `step`, or when either split is not a whole
                number of batches.
        """
        assert stage in ["fit", None] # fit is for train/val, None is for all. "test" not supported ATM

        if self.hparams.data_dir is not None:
            coseg.config.DATA_PATH = self.hparams.data_dir
        step = self.hparams.step # brevity

        dataset = self.mk_ad_dataset()
        n_items_pre_step_mapping = len(dataset)

        if step > 1:
            dataset = common.TransformExtendedDataset(dataset)

        for sx in range(step):
            for sy in range(step):
                def make_slicer(sx, sy, step) -> callable: # the closure is required
                    if step > 1:
                        return lambda t: t[sx::step, sy::step]
                    else:
                        return lambda t: t

                @dataset.map(slicer=make_slicer(sx, sy, step))
                def unpack(sample: tuple[str, SingleViewUVScan], slicer: callable):
                    scan: SingleViewUVScan = sample[1]
                    assert not scan.hits.shape[0] % step, f"{scan.hits.shape[0]=} not divisible by {step=}"
                    assert not scan.hits.shape[1] % step, f"{scan.hits.shape[1]=} not divisible by {step=}"

                    return {
                        "z_uid"     : sample[0],
                        "origins"   : scan.cam_pos,
                        "dirs"      : slicer(scan.ray_dirs),
                        "points"    : slicer(scan.points),
                        "hits"      : slicer(scan.hits),
                        "miss"      : slicer(scan.miss),
                        "normals"   : slicer(scan.normals),
                        "distances" : slicer(scan.distances),
                    }

        # Split each object into train/val with SampleSplit
        n_items = len(dataset)
        n_val   = int(n_items * self.hparams.val_fraction)
        n_train = n_items - n_val
        self.generator = torch.Generator().manual_seed(self.hparams.prng_seed)

        # split the dataset such that all steps are in same part
        # NOTE(review): n_train is not rounded to a multiple of step*step, so the
        # cut below can still land inside one shape's group of sub-grids.
        assert n_items == n_items_pre_step_mapping * step * step, (n_items, n_items_pre_step_mapping, step)
        indices = [
            i*step*step + sx*step + sy
            for i in torch.randperm(n_items_pre_step_mapping, generator=self.generator).tolist()
            for sx in range(step)
            for sy in range(step)
        ]
        # sorting by a random key shuffles each part using the seeded generator
        self.dataset_train = Subset(dataset, sorted(indices[:n_train],                key=lambda x: torch.rand(1, generator=self.generator).tolist()[0]))
        self.dataset_val   = Subset(dataset, sorted(indices[n_train:n_train+n_val],   key=lambda x: torch.rand(1, generator=self.generator).tolist()[0]))

        assert len(self.dataset_train) % self.hparams.batch_size == 0
        assert len(self.dataset_val)   % self.hparams.batch_size == 0

    def train_dataloader(self):
        """DataLoader over the training split, configured from `hparams`."""
        return DataLoader(self.dataset_train,
            batch_size         = self.hparams.batch_size,
            drop_last          = self.hparams.drop_last,
            num_workers        = self.hparams.num_workers,
            persistent_workers = self.hparams.persistent_workers,
            pin_memory         = self.hparams.pin_memory,
            prefetch_factor    = self.hparams.prefetch_factor,
            shuffle            = self.hparams.shuffle,
            generator          = self.generator,
        )

    def val_dataloader(self):
        """DataLoader over the validation split (no `shuffle` argument)."""
        return DataLoader(self.dataset_val,
            batch_size         = self.hparams.batch_size,
            drop_last          = self.hparams.drop_last,
            num_workers        = self.hparams.num_workers,
            persistent_workers = self.hparams.persistent_workers,
            pin_memory         = self.hparams.pin_memory,
            prefetch_factor    = self.hparams.prefetch_factor,
            generator          = self.generator,
        )


class StanfordUVDataModule(RayFieldAdDataModuleBase):
    """Data module backed by `stanford.AutodecoderSingleViewUVScanDataset`."""

    skyward = "+Z"

    def __init__(self,
            data_dir           : Union[str, Path, None] = None,
            obj_names          : list[str]              = ["bunny"], # empty means all

            prng_seed          : int   = 1337,
            step               : int   = 2,
            batch_size         : int   = 5,
            drop_last          : bool  = False,
            num_workers        : int   = 8,
            persistent_workers : bool  = True,
            pin_memory         : int   = True,
            prefetch_factor    : int   = 2,
            shuffle            : bool  = True,
            val_fraction       : float = 0.30,
            ):
        super().__init__()
        if not obj_names:
            obj_names = stanford_read.list_object_names()
        # captures the (possibly replaced) obj_names along with all other arguments
        self.save_hyperparameters()

    @property
    def observation_ids(self) -> list[str]:
        return self.hparams.obj_names

    def mk_ad_dataset(self) -> common.AutodecoderDataset:
        return stanford.AutodecoderSingleViewUVScanDataset(
            obj_names = self.hparams.obj_names,
            data_path = self.hparams.data_dir,
        )

    @staticmethod
    def get_trimesh_from_uid(obj_name) -> Trimesh:
        """Read the mesh for `obj_name` and pass it through `scale_to_unit_sphere`."""
        import mesh_to_sdf
        mesh = stanford_read.read_mesh(obj_name)
        return mesh_to_sdf.scale_to_unit_sphere(mesh)

    @staticmethod
    def get_sphere_scan_from_uid(obj_name) -> SingleViewUVScan:
        return stanford_read.read_mesh_mesh_sphere_scan(obj_name)


class CosegUVDataModule(RayFieldAdDataModuleBase):
    """Data module backed by `coseg.AutodecoderSingleViewUVScanDataset`."""

    skyward = "+Y"

    def __init__(self,
            data_dir           : Union[str, Path, None] = None,
            object_sets        : tuple[str]             = ["tele-aliens"], # empty means all

            prng_seed          : int   = 1337,
            step               : int   = 2,
            batch_size         : int   = 5,
            drop_last          : bool  = False,
            num_workers        : int   = 8,
            persistent_workers : bool  = True,
            pin_memory         : int   = True,
            prefetch_factor    : int   = 2,
            shuffle            : bool  = True,
            val_fraction       : float = 0.30,
            ):
        super().__init__()
        if not object_sets:
            object_sets = coseg_read.list_object_sets()
        object_sets = tuple(object_sets)
        # captures the normalized object_sets along with all other arguments
        self.save_hyperparameters()

    @property
    def observation_ids(self) -> list[str]:
        return coseg_read.list_model_id_strings(self.hparams.object_sets)

    def mk_ad_dataset(self) -> common.AutodecoderDataset:
        return coseg.AutodecoderSingleViewUVScanDataset(
            object_sets = self.hparams.object_sets,
            data_path   = self.hparams.data_dir,
        )

    @staticmethod
    def get_trimesh_from_uid(string_uid):
        """Not available for COSEG; always raises `NotImplementedError`."""
        raise NotImplementedError

    @staticmethod
    def get_sphere_scan_from_uid(string_uid) -> SingleViewUVScan:
        uid = coseg_read.model_id_string_to_uid(string_uid)
        return coseg_read.read_mesh_mesh_sphere_scan(*uid)


def mk_cli(args=None) -> CliInterface:
    """Build the experiment CLI and register its callbacks and actions.

    Registers a pre-training callback plus the actions `viewer`,
    `render_video_interpolation`, `render_video_spin` and `compute_scores`.
    `args` is accepted for interface compatibility but is not read here.

    Returns:
        The configured `CliInterface`; call `.run()` on it to execute.
    """
    cli = CliInterface(
        module_cls             = IField,
        datamodule_cls         = [StanfordUVDataModule, CosegUVDataModule],
        workdir                = Path(__file__).parent.resolve(),
        experiment_name_prefix = "ifield",
    )
    cli.trainer_defaults.update(dict(
        precision  = 16,
        min_epochs = 5,
    ))

    def vizmode_help(modes) -> str:
        # Help text enumerating each mode as "index: name", since the options take the index.
        return f"Rendering mode. {{{', '.join(f'{i}: {m!r}' for i, m in enumerate(modes))}}}"

    @cli.register_pre_training_callback
    def populate_autodecoder_z_uids(args: Namespace, config: Munch, module: IField, trainer: pl.Trainer, datamodule: RayFieldAdDataModuleBase, logger: logging.Logger):
        """Hand the datamodule's observation ids to the model before training."""
        module.set_observation_ids(datamodule.observation_ids)
        rank = getattr(rank_zero_only, "rank", 0)
        rich.print(f"[rank {rank}] {len(datamodule.observation_ids) = }")
        rich.print(f"[rank {rank}] {len(datamodule.observation_ids) > 1 = }")
        rich.print(f"[rank {rank}] {module.is_conditioned = }")

    @cli.register_action(help="Interactive window with direct renderings from the model", args=[
        ("--shading",   dict(type=int, default=ModelViewer.vizmodes_shading  .index("lambertian"),             help=vizmode_help(ModelViewer.vizmodes_shading))),
        ("--centroid",  dict(type=int, default=ModelViewer.vizmodes_centroids.index("best-centroids-colored"), help=vizmode_help(ModelViewer.vizmodes_centroids))),
        ("--spheres",   dict(type=int, default=ModelViewer.vizmodes_spheres  .index(None),                     help=vizmode_help(ModelViewer.vizmodes_spheres))),
        ("--analytical-normals", dict(action="store_true")),
        ("--ground-truth",       dict(action="store_true")),
        ("--solo-atom", dict(type=int, default=None, help="Rendering mode")),
        ("--res",       dict(type=int, nargs=2, default=(210, 160), help="Rendering resolution")),
        ("--bg",        dict(choices=["map", "white", "black"], default="map")),
        ("--skyward",   dict(type=str, default="+Z", help='one of: "+X", "-X", "+Y", "-Y", ["+Z"], "-Z"')),
        ("--scale",     dict(type=int, default=3, help="Rendering scale")),
        ("--fps",       dict(type=int, default=None, help="FPS upper limit")),
        ("--cam-state", dict(type=str, default=None, help="json cam state, expored with CTRL+H")),
        ("--write",     dict(type=Path, default=None, help="Where to write a screenshot.")),
    ])
    @torch.no_grad()
    def viewer(args: Namespace, config: Munch, model: IField):
        """Open the interactive viewer, or write one PNG frame when `--write` is given."""
        datamodule_cls: type[RayFieldAdDataModuleBase] = cli.get_datamodule_cls_from_config(args, config)

        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            model.to("cuda")
        viewer = ModelViewer(model, start_uid=next(iter(model.keys())),
            name           = config.experiment_name,
            screenshot_dir = Path(__file__).parent.parent / "images/pygame-viewer",
            res            = args.res,
            skyward        = args.skyward,
            scale          = args.scale,
            mesh_gt_getter = datamodule_cls.get_trimesh_from_uid,
        )
        viewer.display_mode_shading  = args.shading
        viewer.display_mode_centroid = args.centroid
        viewer.display_mode_spheres  = args.spheres
        # checked in this order, so --analytical-normals wins when both flags are given
        if args.ground_truth:
            viewer.display_mode_normals = viewer.vizmodes_normals.index("ground_truth")
        if args.analytical_normals:
            viewer.display_mode_normals = viewer.vizmodes_normals.index("analytical")
        viewer.atom_index_solo = args.solo_atom
        viewer.fps_cap = args.fps
        viewer.display_sphere_map_bg = { "map": True, "white": 255, "black": 0 }[args.bg]
        if args.cam_state is not None:
            viewer.cam_state = json.loads(args.cam_state)
        if args.write is None:
            viewer.run()
        else:
            assert args.write.suffix == ".png", args.write.name
            viewer.render_headless(args.write,
                n_frames       = 1,
                fps            = 1,
                state_callback = None,
            )

    @cli.register_action(help="Prerender direct renderings from the model", args=[
        ("output_path", dict(type=Path, help="Where to store the output. We recommend a .mp4 suffix.")),
        ("uids",        dict(type=str, nargs="*")),
        ("--frames",    dict(type=int, default=60, help="Number of per interpolation. Default is 60")),
        ("--fps",       dict(type=int, default=60, help="Default is 60")),
        ("--shading",   dict(type=int, default=ModelViewer.vizmodes_shading  .index("lambertian"),             help=vizmode_help(ModelViewer.vizmodes_shading))),
        ("--centroid",  dict(type=int, default=ModelViewer.vizmodes_centroids.index("best-centroids-colored"), help=vizmode_help(ModelViewer.vizmodes_centroids))),
        ("--spheres",   dict(type=int, default=ModelViewer.vizmodes_spheres  .index(None),                     help=vizmode_help(ModelViewer.vizmodes_spheres))),
        ("--analytical-normals", dict(action="store_true")),
        ("--solo-atom", dict(type=int, default=None, help="Rendering mode")),
        ("--res",       dict(type=int, nargs=2, default=(240, 240), help="Rendering resolution. Default is 240 240")),
        ("--bg",        dict(choices=["map", "white", "black"], default="map")),
        ("--skyward",   dict(type=str, default="+Z", help='one of: "+X", "-X", "+Y", "-Y", ["+Z"], "-Z"')),
        ("--bitrate",   dict(type=str, default="1500k", help="Encoding bitrate. Default is 1500k")),
        ("--cam-state", dict(type=str, default=None, help="json cam state, expored with CTRL+H")),
    ])
    @torch.no_grad()
    def render_video_interpolation(args: Namespace, config: Munch, model: IField, **kw):
        """Render a video that steps through a sequence of uids, `--frames` frames each.

        Without explicit `uids`, all model keys are used and the first one is
        appended again so that the sequence ends where it started.
        """
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            model.to("cuda")
        uids = args.uids or list(model.keys())
        assert len(uids) > 1
        if not args.uids:
            uids.append(uids[0])
        viewer = ModelViewer(model, uids[0],
            name           = config.experiment_name,
            screenshot_dir = Path(__file__).parent.parent / "images/pygame-viewer",
            res            = args.res,
            skyward        = args.skyward,
        )
        if args.cam_state is not None:
            viewer.cam_state = json.loads(args.cam_state)
        viewer.display_mode_shading  = args.shading
        viewer.display_mode_centroid = args.centroid
        viewer.display_mode_spheres  = args.spheres
        if args.analytical_normals:
            viewer.display_mode_normals = viewer.vizmodes_normals.index("analytical")
        viewer.atom_index_solo = args.solo_atom
        viewer.display_sphere_map_bg = { "map": True, "white": 255, "black": 0 }[args.bg]

        def state_callback(self: ModelViewer, frame: int):
            # frames that are a multiple of args.frames get the white color, the rest a tinted one
            if frame % args.frames:
                self.lambertian_color = (0.8, 0.8, 1.0)
            else:
                self.lambertian_color = (1.0, 1.0, 1.0)
            self.fps = args.frames
            idx = frame // args.frames + 1
            if idx != len(uids):
                self.current_uid = uids[idx]

        print(f"Writing video to {str(args.output_path)!r}...")
        viewer.render_headless(args.output_path,
            n_frames       = args.frames * (len(uids)-1) + 1,
            fps            = args.fps,
            state_callback = state_callback,
            bitrate        = args.bitrate,
        )

    @cli.register_action(help="Prerender direct renderings from the model", args=[
        ("output_path", dict(type=Path, help="Where to store the output. We recommend a .mp4 suffix.")),
        ("--frames",    dict(type=int, default=180, help="Number of frames. Default is 180")),
        ("--fps",       dict(type=int, default=60, help="Default is 60")),
        ("--shading",   dict(type=int, default=ModelViewer.vizmodes_shading  .index("lambertian"),             help=vizmode_help(ModelViewer.vizmodes_shading))),
        ("--centroid",  dict(type=int, default=ModelViewer.vizmodes_centroids.index("best-centroids-colored"), help=vizmode_help(ModelViewer.vizmodes_centroids))),
        ("--spheres",   dict(type=int, default=ModelViewer.vizmodes_spheres  .index(None),                     help=vizmode_help(ModelViewer.vizmodes_spheres))),
        ("--analytical-normals", dict(action="store_true")),
        ("--solo-atom", dict(type=int, default=None, help="Rendering mode")),
        ("--res",       dict(type=int, nargs=2, default=(320, 240), help="Rendering resolution. Default is 320 240")),
        ("--bg",        dict(choices=["map", "white", "black"], default="map")),
        ("--skyward",   dict(type=str, default="+Z", help='one of: "+X", "-X", "+Y", "-Y", ["+Z"], "-Z"')),
        ("--bitrate",   dict(type=str, default="1500k", help="Encoding bitrate. Default is 1500k")),
        ("--cam-state", dict(type=str, default=None, help="json cam state, expored with CTRL+H")),
    ])
    @torch.no_grad()
    def render_video_spin(args: Namespace, config: Munch, model: IField, **kw):
        """Render a video in which `cam_rot_x` advances by one full turn over `--frames` frames."""
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            model.to("cuda")
        viewer = ModelViewer(model, start_uid=next(iter(model.keys())),
            name           = config.experiment_name,
            screenshot_dir = Path(__file__).parent.parent / "images/pygame-viewer",
            res            = args.res,
            skyward        = args.skyward,
        )
        if args.cam_state is not None:
            viewer.cam_state = json.loads(args.cam_state)
        viewer.display_mode_shading  = args.shading
        viewer.display_mode_centroid = args.centroid
        viewer.display_mode_spheres  = args.spheres
        if args.analytical_normals:
            viewer.display_mode_normals = viewer.vizmodes_normals.index("analytical")
        viewer.atom_index_solo = args.solo_atom
        viewer.display_sphere_map_bg = { "map": True, "white": 255, "black": 0 }[args.bg]

        cam_rot_x_init = viewer.cam_rot_x

        def state_callback(self: ModelViewer, frame: int):
            # exactly 2*pi over the clip, so the last frame lines up with the first when looped
            self.cam_rot_x = cam_rot_x_init + 2 * np.pi * (frame / args.frames)

        print(f"Writing video to {str(args.output_path)!r}...")
        viewer.render_headless(args.output_path,
            n_frames       = args.frames,
            fps            = args.fps,
            state_callback = state_callback,
            bitrate        = args.bitrate,
        )

    @cli.register_action(help="Compute evaluation scores (IoU, F-score, MSE, normal similarity, chamfer distance) and write them as json", args=[
        ("fname",             dict(type=Path, help="where to write json")),
        ("-t", "--transpose", dict(action="store_true", help="transpose the output")),
        ("--single-shape",    dict(action="store_true", help="break after first shape")),
        ("--batch-size",      dict(type=int, default=40_000, help="tradeoff between vram usage and efficiency")),
        ("--n-cd",            dict(type=int, default=30_000, help="Number of points to use when computing chamfer distance")),
        ("--filter-outliers", dict(action="store_true", help="like in PRIF")),
    ])
    @torch.enable_grad()
    def compute_scores(args: Namespace, config: Munch, model: IField, **kw):
        """Score the model against the ground-truth sphere scan of every uid.

        Per uid, rays are evaluated in batches of `--batch-size`. The function
        accumulates hit/miss confusion counts, squared errors and cosine
        distances, then computes a chamfer distance on `--n-cd` randomly
        chosen points. A summary is printed and all metrics are written as
        json to `fname` (or to stdout when `fname` is "-").

        Gradients are enabled because normals are derived from the jacobian of
        the intersections with respect to the ray origins.
        """
        datamodule_cls: type[RayFieldAdDataModuleBase] = cli.get_datamodule_cls_from_config(args, config)
        model.eval()
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            model.to("cuda")

        def T(array: np.ndarray, **kw) -> torch.Tensor:
            # Move ground truth onto the model's device; floating-point arrays are cast
            # to the model's dtype, all other dtypes (masks, ...) are kept as they are.
            if isinstance(array, torch.Tensor):
                return array
            is_floating = np.issubdtype(np.asarray(array).dtype, np.floating)
            return torch.tensor(array, device=model.device, dtype=model.dtype if is_floating else None, **kw)

        MEDIAL = model.hparams.output_mode == "medial_sphere"
        if not MEDIAL:
            assert model.hparams.output_mode == "orthogonal_plane"

        uids = sorted(model.keys())
        if args.single_shape:
            uids = [uids[0]]
        rich.print(f"{datamodule_cls.__name__ = }")
        rich.print(f"{len(uids) = }")

        # accumulators for IoU and F-Score, CD and COS

        # sum reduction:
        n            = defaultdict(int)
        n_gt_hits    = defaultdict(int)
        n_gt_miss    = defaultdict(int)
        n_gt_missing = defaultdict(int)
        n_outliers   = defaultdict(int)
        p_mse        = defaultdict(int)
        s_mse        = defaultdict(int)
        cossim_med   = defaultdict(int) # medial normals
        cossim_jac   = defaultdict(int) # jacobian normals
        TP,FN,FP,TN  = [defaultdict(int) for _ in range(4)] # IoU and f-score
        # mean reduction:
        cd_dist      = {} # chamfer distance
        cd_cos_med   = {} # chamfer medial normals
        cd_cos_jac   = {} # chamfer jacobian normals
        all_metrics  = dict(
            n=n, n_gt_hits=n_gt_hits, n_gt_miss=n_gt_miss, n_gt_missing=n_gt_missing, p_mse=p_mse,
            cossim_jac=cossim_jac,
            TP=TP, FN=FN, FP=FP, TN=TN, cd_dist=cd_dist,
            cd_cos_jac=cd_cos_jac,
        )
        if MEDIAL:
            all_metrics["s_mse"]      = s_mse
            all_metrics["cossim_med"] = cossim_med
            all_metrics["cd_cos_med"] = cd_cos_med
        if args.filter_outliers:
            all_metrics["n_outliers"] = n_outliers

        t = datetime.now()
        for uid in tqdm(uids, desc="Dataset", position=0, leave=True, disable=len(uids)<=1):
            sphere_scan_gt = datamodule_cls.get_sphere_scan_from_uid(uid)

            z = model[uid].detach()

            all_intersections    = []
            all_medial_normals   = []
            all_jacobian_normals = []

            step = args.batch_size
            for i in tqdm(range(0, sphere_scan_gt.hits.shape[0], step), desc=f"Item {uid!r}", position=1, leave=False):
                # prepare batch and gt
                origins      = T(sphere_scan_gt.cam_pos  [i:i+step, :], requires_grad = True)
                dirs         = T(sphere_scan_gt.ray_dirs [i:i+step, :])
                gt_hits      = T(sphere_scan_gt.hits     [i:i+step])
                gt_miss      = T(sphere_scan_gt.miss     [i:i+step])
                gt_missing   = T(sphere_scan_gt.missing  [i:i+step])
                gt_points    = T(sphere_scan_gt.points   [i:i+step, :])
                gt_normals   = T(sphere_scan_gt.normals  [i:i+step, :])
                gt_distances = T(sphere_scan_gt.distances[i:i+step])

                # forward
                if MEDIAL:
                    (
                        depths,
                        silhouettes,
                        intersections,
                        medial_normals,
                        is_intersecting,
                        sphere_centers,
                        sphere_radii,
                    ) = model({
                        "origins" : origins,
                        "dirs"    : dirs,
                    }, z, intersections_only=False, allow_nans=False)
                else:
                    silhouettes = medial_normals = None
                    intersections, is_intersecting = model({
                        "origins" : origins,
                        "dirs"    : dirs,
                    }, z, normalize_origins = True)
                # threshold to a boolean mask (a no-op when it is already boolean)
                is_intersecting = is_intersecting > 0.5
                jac = diff.jacobian(intersections, origins, detach=True)

                # outlier removal (PRIF)
                if args.filter_outliers:
                    outliers = jac.norm(dim=-2).norm(dim=-1) > 5
                    n_outliers[uid] += outliers[is_intersecting].sum().item()
                    # We count filtered points as misses
                    is_intersecting &= ~outliers

                model.zero_grad()
                jacobian_normals = model.compute_normals_from_intersection_origin_jacobian(jac, dirs)

                all_intersections.append(intersections.detach()[is_intersecting.detach(), :])
                if MEDIAL:
                    all_medial_normals.append(medial_normals.detach()[is_intersecting.detach(), :])
                all_jacobian_normals.append(jacobian_normals.detach()[is_intersecting.detach(), :])

                # accumulate metrics
                with torch.no_grad():
                    n           [uid] += dirs.shape[0]
                    n_gt_hits   [uid] += gt_hits.sum().item()
                    n_gt_miss   [uid] += gt_miss.sum().item()
                    n_gt_missing[uid] += gt_missing.sum().item()
                    p_mse       [uid] += (gt_points[gt_hits, :] - intersections[gt_hits, :]).norm(2, dim=-1).pow(2).sum().item()
                    if MEDIAL:
                        s_mse     [uid] += (gt_distances[gt_miss] - silhouettes[gt_miss]).pow(2).sum().item()
                        cossim_med[uid] += (1-F.cosine_similarity(gt_normals[gt_hits, :], medial_normals[gt_hits, :], dim=-1).abs()).sum().item() # to match what pytorch3d does for CD
                    cossim_jac  [uid] += (1-F.cosine_similarity(gt_normals[gt_hits, :], jacobian_normals[gt_hits, :], dim=-1).abs()).sum().item() # to match what pytorch3d does for CD
                    not_intersecting = ~is_intersecting
                    TP[uid] += ((gt_hits | gt_missing) & is_intersecting ).sum().item() # True Positive
                    FN[uid] += ((gt_hits | gt_missing) & not_intersecting).sum().item() # False Negative
                    FP[uid] += (gt_miss & is_intersecting ).sum().item() # False Positive
                    TN[uid] += (gt_miss & not_intersecting).sum().item() # True Negative

            all_intersections    = torch.cat(all_intersections,    dim=0)
            all_medial_normals   = torch.cat(all_medial_normals,   dim=0) if MEDIAL else None
            all_jacobian_normals = torch.cat(all_jacobian_normals, dim=0)

            hits = sphere_scan_gt.hits # brevity
            print()

            assert all_intersections.shape[0] >= args.n_cd
            idx_cd_pred = torch.randperm(all_intersections.shape[0])[:args.n_cd]
            idx_cd_gt   = torch.randperm(int(hits.sum()))           [:args.n_cd]

            print("cd... ", end="")
            tt = datetime.now()
            loss_cd, loss_cos_jac = chamfer_distance(
                x         = all_intersections   [None, :, :][:, idx_cd_pred, :].detach(),
                x_normals = all_jacobian_normals[None, :, :][:, idx_cd_pred, :].detach(),
                y         = T(sphere_scan_gt.points [None, hits, :][:, idx_cd_gt, :]),
                y_normals = T(sphere_scan_gt.normals[None, hits, :][:, idx_cd_gt, :]),
                batch_reduction = "sum", point_reduction = "sum",
            )
            if MEDIAL:
                _, loss_cos_med = chamfer_distance(
                    x         = all_intersections [None, :, :][:, idx_cd_pred, :].detach(),
                    x_normals = all_medial_normals[None, :, :][:, idx_cd_pred, :].detach(),
                    y         = T(sphere_scan_gt.points [None, hits, :][:, idx_cd_gt, :]),
                    y_normals = T(sphere_scan_gt.normals[None, hits, :][:, idx_cd_gt, :]),
                    batch_reduction = "sum", point_reduction = "sum",
                )
            print(datetime.now() - tt)

            cd_dist   [uid] = loss_cd.item()
            cd_cos_med[uid] = loss_cos_med.item() if MEDIAL else None
            cd_cos_jac[uid] = loss_cos_jac.item()

            print()
            model.zero_grad(set_to_none=True)

        print("Total time:", datetime.now() - t)
        if len(uids) > 1:
            print("Time per item:", (datetime.now() - t) / len(uids))

        # Reductions over the values of one or more per-uid dicts. The short names
        # (including the shadowed `sum`) are kept on purpose: the `{expr = }`
        # f-strings below echo the expression text into the printed report.
        sum   = lambda *xs: builtins  .sum  (itertools.chain(*(x.values() for x in xs)))
        mean  = lambda *xs: statistics.mean (itertools.chain(*(x.values() for x in xs)))
        stdev = lambda *xs: statistics.stdev(itertools.chain(*(x.values() for x in xs)))
        n_cd = args.n_cd
        P = sum(TP)/(sum(TP, FP))
        R = sum(TP)/(sum(TP, FN))
        print(f"{mean(n) = :11.1f} (rays per object)")
        print(f"{mean(n_gt_hits) = :11.1f} (gt rays hitting per object)")
        print(f"{mean(n_gt_miss) = :11.1f} (gt rays missing per object)")
        print(f"{mean(n_gt_missing) = :11.1f} (gt rays unknown per object)")
        if args.filter_outliers:
            print(f"{mean(n_outliers) = :11.1f} (outlier rays filtered per object)")
        print(f"{n_cd = :11.0f} (cd rays per object)")
        print(f"{mean(n_gt_hits) / mean(n) = :11.8f} (fraction rays hitting per object)")
        print(f"{mean(n_gt_miss) / mean(n) = :11.8f} (fraction rays missing per object)")
        print(f"{mean(n_gt_missing)/ mean(n) = :11.8f} (fraction rays unknown per object)")
        if args.filter_outliers:
            print(f"{mean(n_outliers) / mean(n) = :11.8f} (fraction outlier rays filtered per object)")
        print(f"{sum(TP)/sum(n) = :11.8f} (total ray TP)")
        print(f"{sum(TN)/sum(n) = :11.8f} (total ray TN)")
        print(f"{sum(FP)/sum(n) = :11.8f} (total ray FP)")
        print(f"{sum(FN)/sum(n) = :11.8f} (total ray FN)")
        print(f"{sum(TP, FN, FP)/sum(n) = :11.8f} (total ray union)")
        print(f"{sum(TP)/sum(TP, FN, FP) = :11.8f} (total ray IoU)")
        print(f"{sum(TP)/(sum(TP, FP)) = :11.8f} -> P (total ray precision)")
        print(f"{sum(TP)/(sum(TP, FN)) = :11.8f} -> R (total ray recall)")
        print(f"{2*(P*R)/(P+R) = :11.8f} (total ray F-score)")
        print(f"{sum(p_mse)/sum(n_gt_hits) = :11.8f} (mean ray intersection mean squared error)")
        if MEDIAL:
            # silhouette error and medial normals are only accumulated in medial mode
            print(f"{sum(s_mse)/sum(n_gt_miss) = :11.8f} (mean ray silhoutette mean squared error)")
            print(f"{sum(cossim_med)/sum(n_gt_hits) = :11.8f} (mean ray medial reduced cosine similarity)")
        print(f"{sum(cossim_jac)/sum(n_gt_hits) = :11.8f} (mean ray analytical reduced cosine similarity)")
        print(f"{mean(cd_dist) /n_cd * 1e3 = :11.8f} (mean chamfer distance)")
        if MEDIAL:
            print(f"{mean(cd_cos_med)/n_cd = :11.8f} (mean chamfer reduced medial cossim distance)")
        print(f"{mean(cd_cos_jac)/n_cd = :11.8f} (mean chamfer reduced analytical cossim distance)")
        if len(cd_dist) > 1:
            print(f"{stdev(cd_dist) /n_cd * 1e3 = :11.8f} (stdev chamfer distance)")
        if len(cd_cos_med) > 1 and MEDIAL:
            print(f"{stdev(cd_cos_med)/n_cd = :11.8f} (stdev chamfer reduced medial cossim distance)")
        if len(cd_cos_jac) > 1:
            print(f"{stdev(cd_cos_jac)/n_cd = :11.8f} (stdev chamfer reduced analytical cossim distance)")

        if args.transpose:
            # flip {metric: {uid: value}} into {uid: {metric: value}}
            all_metrics, old_metrics = defaultdict(dict), all_metrics
            for m, table in old_metrics.items():
                for uid, vals in table.items():
                    all_metrics[uid][m] = vals
            all_metrics["_hparams"] = dict(n_cd=args.n_cd)
        else:
            all_metrics["n_cd"] = args.n_cd

        if str(args.fname) == "-":
            print("{", ',\n'.join(
                f" {json.dumps(k)}: {json.dumps(v)}"
                for k, v in all_metrics.items()
            ), "}", sep="\n")
        else:
            args.fname.parent.mkdir(parents=True, exist_ok=True)
            with args.fname.open("w") as f:
                json.dump(all_metrics, f, indent=2)

    return cli


if __name__ == "__main__":
    mk_cli().run()