#!/usr/bin/env python3
from abc import ABC, abstractmethod
from argparse import Namespace
from collections import defaultdict
from datetime import datetime
from ifield import logging
from ifield.cli import CliInterface
from ifield.data.common.scan import SingleViewUVScan
from ifield.data.coseg import read as coseg_read
from ifield.data.stanford import read as stanford_read
from ifield.datasets import stanford, coseg, common
from ifield.models import intersection_fields
from ifield.utils.operators import diff
from ifield.viewer.ray_field import ModelViewer
from munch import Munch
from pathlib import Path
from pytorch3d.loss.chamfer import chamfer_distance
from pytorch_lightning.utilities import rank_zero_only
from torch.nn import functional as F
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
from trimesh import Trimesh
from typing import Union
import builtins
import itertools
import json
import numpy as np
import pytorch_lightning as pl
import rich
import rich.pretty
import statistics
import torch
# Module-level side effects: both calls below run as soon as this file is imported.
# Fix the seed so that runs are reproducible.
pl.seed_everything(31337)
# Let PyTorch trade float32 matmul precision for speed.
torch.set_float32_matmul_precision('medium')


# Short alias for the auto-decoder model class used throughout this file.
IField = intersection_fields.IntersectionFieldAutoDecoderModel # brevity


class RayFieldAdDataModuleBase(pl.LightningDataModule, ABC):
    """
    Common train/val plumbing for data modules serving `SingleViewUVScan` samples.

    Subclasses provide the dataset (`mk_ad_dataset`), the observation ids and
    the ground truth getters. They are expected to store the following
    hyperparameters in `self.hparams`: `data_dir`, `step`, `val_fraction`,
    `prng_seed`, `batch_size`, `drop_last`, `num_workers`,
    `persistent_workers`, `pin_memory`, `prefetch_factor` and `shuffle`.
    """

    @property
    @abstractmethod
    def observation_ids(self) -> list[str]:
        """Identifiers of the observations served by this data module."""
        ...

    @abstractmethod
    def mk_ad_dataset(self) -> common.AutodecoderDataset:
        """Construct the dataset which `setup` maps, splits and wraps."""
        ...

    @staticmethod
    @abstractmethod
    def get_trimesh_from_uid(uid) -> Trimesh:
        """Return the ground truth mesh belonging to `uid`."""
        ...

    @staticmethod
    @abstractmethod
    def get_sphere_scan_from_uid(uid) -> SingleViewUVScan:
        """Return the ground truth sphere scan belonging to `uid`."""
        ...

    def setup(self, stage=None):
        """
        Build `self.dataset_train` and `self.dataset_val`.

        Every scan is subsampled into `step * step` interleaved sub-grids
        (`t[sx::step, sy::step]`), each of which becomes a separate dataset item.
        The items are then split into a training and a validation part, such
        that all sub-grids of a scan end up in the same part.

        Raises `AssertionError` on unsupported stages, if the mapped dataset
        has an unexpected size, or if a split is not divisible by the batch size.
        """
        assert stage in ["fit", None] # fit is for train/val, None is for all. "test" not supported ATM

        if self.hparams.data_dir is not None:
            # NOTE(review): this overrides the coseg data path for every subclass, not only the coseg one - confirm this is intended
            coseg.config.DATA_PATH = self.hparams.data_dir
        step = self.hparams.step # brevity
        n_sub_grids = step * step # number of dataset items produced per scan

        dataset = self.mk_ad_dataset()
        n_items_pre_step_mapping = len(dataset)

        if step > 1:
            dataset = common.TransformExtendedDataset(dataset)

        def make_slicer(sx, sy, step) -> callable: # the closure is required
            # binds sx/sy now, the loop variables below change before the lambda is called
            if step > 1:
                return lambda t: t[sx::step, sy::step]
            else:
                return lambda t: t

        # register one mapping per sub-grid offset
        for sx in range(step):
            for sy in range(step):
                @dataset.map(slicer=make_slicer(sx, sy, step))
                def unpack(sample: tuple[str, SingleViewUVScan], slicer: callable):
                    scan: SingleViewUVScan = sample[1]
                    assert not scan.hits.shape[0] % step, f"{scan.hits.shape[0]=} not divisible by {step=}"
                    assert not scan.hits.shape[1] % step, f"{scan.hits.shape[1]=} not divisible by {step=}"

                    return {
                        "z_uid"     : sample[0],
                        "origins"   : scan.cam_pos,
                        "dirs"      : slicer(scan.ray_dirs),
                        "points"    : slicer(scan.points),
                        "hits"      : slicer(scan.hits),
                        "miss"      : slicer(scan.miss),
                        "normals"   : slicer(scan.normals),
                        "distances" : slicer(scan.distances),
                    }

        # Split each object into train/val with SampleSplit
        n_items = len(dataset)
        assert n_items == n_items_pre_step_mapping * n_sub_grids, (n_items, n_items_pre_step_mapping, step)
        n_val   = int(n_items * self.hparams.val_fraction)
        # Round down to whole scans. Otherwise the split boundary may fall inside
        # the group of sub-grids of a single scan, leaking it into both parts.
        n_val  -= n_val % n_sub_grids
        n_train = n_items - n_val
        self.generator = torch.Generator().manual_seed(self.hparams.prng_seed)

        # split the dataset such that all steps are in same part
        indices = [
            i*n_sub_grids + sx*step + sy
            for i in torch.randperm(n_items_pre_step_mapping, generator=self.generator).tolist()
            for sx in range(step)
            for sy in range(step)
        ]

        def shuffled(part: list[int]) -> list[int]:
            # seeded shuffle: sort by one random key drawn per element
            return sorted(part, key=lambda x: torch.rand(1, generator=self.generator).tolist()[0])

        # the training part must be shuffled first, both parts draw from the same generator
        self.dataset_train  = Subset(dataset, shuffled(indices[:n_train]))
        self.dataset_val    = Subset(dataset, shuffled(indices[n_train:n_train+n_val]))

        batch_size = self.hparams.batch_size
        assert len(self.dataset_train) % batch_size == 0, f"{len(self.dataset_train)=} not divisible by {batch_size=}"
        assert len(self.dataset_val)   % batch_size == 0, f"{len(self.dataset_val)=} not divisible by {batch_size=}"

    def _mk_dataloader(self, dataset, **overrides) -> DataLoader:
        """Wrap `dataset` in a `DataLoader` configured from the hyperparameters."""
        return DataLoader(dataset,
            batch_size         = self.hparams.batch_size,
            drop_last          = self.hparams.drop_last,
            num_workers        = self.hparams.num_workers,
            persistent_workers = self.hparams.persistent_workers,
            pin_memory         = self.hparams.pin_memory,
            prefetch_factor    = self.hparams.prefetch_factor,
            generator          = self.generator,
            **overrides,
        )

    def train_dataloader(self):
        """Loader over the training part, shuffled if the `shuffle` hyperparameter is set."""
        return self._mk_dataloader(self.dataset_train, shuffle=self.hparams.shuffle)

    def val_dataloader(self):
        """Loader over the validation part, never shuffled by the loader."""
        return self._mk_dataloader(self.dataset_val)


class StanfordUVDataModule(RayFieldAdDataModuleBase):
    """Data module serving single-view UV scans of the stanford objects."""

    skyward = "+Z" # presumably the up-axis of these objects - TODO confirm
    def __init__(self,
            data_dir           : Union[str, Path, None] = None,
            obj_names          : list[str]              = ["bunny"], # empty means all

            prng_seed          : int                    = 1337,
            step               : int                    = 2,
            batch_size         : int                    = 5,
            drop_last          : bool                   = False,
            num_workers        : int                    = 8,
            persistent_workers : bool                   = True,
            pin_memory         : int                    = True,
            prefetch_factor    : int                    = 2,
            shuffle            : bool                   = True,
            val_fraction       : float                  = 0.30,
            ):
        """Store all arguments as hyperparameters. An empty `obj_names` selects every object."""
        super().__init__()
        # resolved before saving, such that the hyperparameters hold the actual names
        obj_names = obj_names or stanford_read.list_object_names()
        self.save_hyperparameters()

    @property
    def observation_ids(self) -> list[str]:
        """The object names double as observation ids."""
        return self.hparams.obj_names

    def mk_ad_dataset(self) -> common.AutodecoderDataset:
        """Create the stanford scan dataset for the selected objects."""
        hparams = self.hparams
        return stanford.AutodecoderSingleViewUVScanDataset(
            data_path = hparams.data_dir,
            obj_names = hparams.obj_names,
        )

    @staticmethod
    def get_trimesh_from_uid(obj_name) -> Trimesh:
        """Read the mesh of `obj_name` and scale it to the unit sphere."""
        from mesh_to_sdf import scale_to_unit_sphere # imported lazily, only needed here
        return scale_to_unit_sphere(stanford_read.read_mesh(obj_name))

    @staticmethod
    def get_sphere_scan_from_uid(obj_name) -> SingleViewUVScan:
        """Read the sphere scan of `obj_name`."""
        sphere_scan = stanford_read.read_mesh_mesh_sphere_scan(obj_name)
        return sphere_scan


class CosegUVDataModule(RayFieldAdDataModuleBase):
    """Data module serving single-view UV scans of the coseg object sets."""

    skyward = "+Y" # presumably the up-axis of these objects - TODO confirm
    def __init__(self,
            data_dir           : Union[str, Path, None] = None,
            object_sets        : tuple[str]             = ["tele-aliens"], # empty means all

            prng_seed          : int                    = 1337,
            step               : int                    = 2,
            batch_size         : int                    = 5,
            drop_last          : bool                   = False,
            num_workers        : int                    = 8,
            persistent_workers : bool                   = True,
            pin_memory         : int                    = True,
            prefetch_factor    : int                    = 2,
            shuffle            : bool                   = True,
            val_fraction       : float                  = 0.30,
            ):
        """Store all arguments as hyperparameters. An empty `object_sets` selects every set."""
        super().__init__()
        # resolved and frozen into a tuple before saving the hyperparameters
        object_sets = tuple(object_sets or coseg_read.list_object_sets())
        self.save_hyperparameters()

    @property
    def observation_ids(self) -> list[str]:
        """The model id strings of all models in the selected object sets."""
        selected_sets = self.hparams.object_sets
        return coseg_read.list_model_id_strings(selected_sets)

    def mk_ad_dataset(self) -> common.AutodecoderDataset:
        """Create the coseg scan dataset for the selected object sets."""
        hparams = self.hparams
        return coseg.AutodecoderSingleViewUVScanDataset(
            data_path   = hparams.data_dir,
            object_sets = hparams.object_sets,
        )

    @staticmethod
    def get_trimesh_from_uid(string_uid):
        """Not available for this dataset, always raises `NotImplementedError`."""
        raise NotImplementedError

    @staticmethod
    def get_sphere_scan_from_uid(string_uid) -> SingleViewUVScan:
        """Read the sphere scan of the model identified by the id string `string_uid`."""
        return coseg_read.read_mesh_mesh_sphere_scan(
            *coseg_read.model_id_string_to_uid(string_uid)
        )


def mk_cli(args=None) -> CliInterface:
    """
    Build the command line interface of this script.

    Creates a `CliInterface` for the `IField` model with the stanford and coseg
    data modules, then registers a pre-training callback and a number of
    actions on it before returning it. `args` is currently not used.
    """
    cli = CliInterface(
        module_cls     = IField,
        datamodule_cls = [StanfordUVDataModule, CosegUVDataModule],
        workdir        = Path(__file__).parent.resolve(), # the directory containing this file
        experiment_name_prefix = "ifield",
    )
    # override the trainer defaults of the cli
    cli.trainer_defaults.update(dict(
        precision  = 16,
        min_epochs =  5,
    ))

    @cli.register_pre_training_callback
    def populate_autodecoder_z_uids(args: Namespace, config: Munch, module: IField, trainer: pl.Trainer, datamodule: RayFieldAdDataModuleBase, logger: logging.Logger):
        """Hand the observation ids of the datamodule to the model, then print a short summary."""
        module.set_observation_ids(datamodule.observation_ids)
        # prefix each line with the process rank, falling back to 0 if it is not set
        prefix = f"[rank {getattr(rank_zero_only, 'rank', 0)}]"
        rich.print(f"{prefix} {len(datamodule.observation_ids)     = }")
        rich.print(f"{prefix} {len(datamodule.observation_ids) > 1 = }")
        rich.print(f"{prefix} {module.is_conditioned               = }")

    @cli.register_action(help="Interactive window with direct renderings from the model", args=[
        ("--shading",  dict(type=int, default=ModelViewer.vizmodes_shading  .index("lambertian"),              help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_shading))}}}")),
        ("--centroid", dict(type=int, default=ModelViewer.vizmodes_centroids.index("best-centroids-colored"),  help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_centroids))}}}")),
        ("--spheres",  dict(type=int, default=ModelViewer.vizmodes_spheres  .index(None),                      help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_spheres))}}}")),
        ("--analytical-normals",  dict(action="store_true")),
        ("--ground-truth",  dict(action="store_true")),
        ("--solo-atom",dict(type=int, default=None, help="Rendering mode")),
        ("--res",      dict(type=int, nargs=2, default=(210, 160), help="Rendering resolution")),
        ("--bg",       dict(choices=["map", "white", "black"], default="map")),
        ("--skyward",  dict(type=str, default="+Z", help='one of: "+X", "-X", "+Y", "-Y", ["+Z"], "-Z"')),
        ("--scale",    dict(type=int, default=3, help="Rendering scale")),
        ("--fps",      dict(type=int, default=None, help="FPS upper limit")),
        ("--cam-state",dict(type=str, default=None, help="json cam state, expored with CTRL+H")),
        ("--write",    dict(type=Path, default=None, help="Where to write a screenshot.")),
    ])
    @torch.no_grad()
    def viewer(args: Namespace, config: Munch, model: IField):
        """
        Show the model in an interactive `ModelViewer` window.

        If `--write` is given, a single frame is rendered headlessly to that
        .png file instead of opening the window.
        """
        # NOTE(review): the datamodule classes define a `skyward` attribute, but
        # the `--skyward` default is "+Z" regardless of the datamodule - confirm this is intended
        datamodule_cls: RayFieldAdDataModuleBase = cli.get_datamodule_cls_from_config(args, config)

        cuda_usable = torch.cuda.is_available() and torch.cuda.device_count() > 0
        if cuda_usable:
            model.to("cuda")

        window = ModelViewer(model, start_uid=next(iter(model.keys())), # start at the first uid
            name           = config.experiment_name,
            screenshot_dir = Path(__file__).parent.parent / "images/pygame-viewer",
            res            = args.res,
            skyward        = args.skyward,
            scale          = args.scale,
            mesh_gt_getter = datamodule_cls.get_trimesh_from_uid,
        )

        # apply the display options from the command line
        window.display_mode_shading  = args.shading
        window.display_mode_centroid = args.centroid
        window.display_mode_spheres  = args.spheres
        # --analytical-normals takes precedence if both normal flags are given
        if args.ground_truth:
            window.display_mode_normals = window.vizmodes_normals.index("ground_truth")
        if args.analytical_normals:
            window.display_mode_normals = window.vizmodes_normals.index("analytical")
        window.atom_index_solo = args.solo_atom
        window.fps_cap         = args.fps
        backgrounds = { "map": True, "white": 255, "black": 0 }
        window.display_sphere_map_bg = backgrounds[args.bg]
        if args.cam_state is not None:
            window.cam_state = json.loads(args.cam_state)

        if args.write is not None:
            # screenshot mode
            assert args.write.suffix == ".png", args.write.name
            window.render_headless(args.write,
                n_frames       = 1,
                fps            = 1,
                state_callback = None,
            )
            return
        window.run()

    @cli.register_action(help="Prerender direct renderings from the model", args=[
        ("output_path",dict(type=Path, help="Where to store the output. We recommend a .mp4 suffix.")),
        ("uids",       dict(type=str, nargs="*")),
        ("--frames",   dict(type=int, default=60, help="Number of per interpolation. Default is 60")),
        ("--fps",      dict(type=int, default=60, help="Default is 60")),
        ("--shading",  dict(type=int, default=ModelViewer.vizmodes_shading  .index("lambertian"),             help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_shading))}}}")),
        ("--centroid", dict(type=int, default=ModelViewer.vizmodes_centroids.index("best-centroids-colored"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_centroids))}}}")),
        ("--spheres",  dict(type=int, default=ModelViewer.vizmodes_spheres  .index(None),                     help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_spheres))}}}")),
        ("--analytical-normals",  dict(action="store_true")),
        ("--solo-atom",dict(type=int, default=None, help="Rendering mode")),
        ("--res",      dict(type=int, nargs=2, default=(240, 240), help="Rendering resolution. Default is 240 240")),
        ("--bg",       dict(choices=["map", "white", "black"], default="map")),
        ("--skyward",  dict(type=str, default="+Z", help='one of: "+X", "-X", "+Y", "-Y", ["+Z"], "-Z"')),
        ("--bitrate",  dict(type=str, default="1500k", help="Encoding bitrate. Default is 1500k")),
        ("--cam-state",dict(type=str, default=None, help="json cam state, expored with CTRL+H")),
    ])
    @torch.no_grad()
    def render_video_interpolation(args: Namespace, config: Munch, model: IField, **kw):
        """
        Render a video which steps through a sequence of uids, `--frames` frames per uid pair.

        Without explicit uids, every uid of the model is visited and the
        sequence ends on the first uid again.
        """
        cuda_usable = torch.cuda.is_available() and torch.cuda.device_count() > 0
        if cuda_usable:
            model.to("cuda")

        uids = args.uids or list(model.keys())
        assert len(uids) > 1
        if not args.uids:
            uids.append(uids[0]) # return to the start

        renderer = ModelViewer(model, uids[0],
            name           = config.experiment_name,
            screenshot_dir = Path(__file__).parent.parent / "images/pygame-viewer",
            res            = args.res,
            skyward        = args.skyward,
        )
        if args.cam_state is not None:
            renderer.cam_state = json.loads(args.cam_state)

        # apply the display options from the command line
        renderer.display_mode_shading  = args.shading
        renderer.display_mode_centroid = args.centroid
        renderer.display_mode_spheres  = args.spheres
        if args.analytical_normals:
            renderer.display_mode_normals = renderer.vizmodes_normals.index("analytical")
        renderer.atom_index_solo = args.solo_atom
        backgrounds = { "map": True, "white": 255, "black": 0 }
        renderer.display_sphere_map_bg = backgrounds[args.bg]

        def state_callback(self: ModelViewer, frame: int):
            # segment: index of the current uid pair, offset: frame within that pair
            segment, offset = divmod(frame, args.frames)
            # the first frame of each segment is rendered white, all others are tinted
            self.lambertian_color = (0.8, 0.8, 1.0) if offset else (1.0, 1.0, 1.0)
            self.fps = args.frames
            target = segment + 1
            if target != len(uids): # the very last frame has no next uid
                self.current_uid = uids[target]

        print(f"Writing video to {str(args.output_path)!r}...")
        n_segments = len(uids) - 1
        renderer.render_headless(args.output_path,
            n_frames       = args.frames * n_segments + 1,
            fps            = args.fps,
            state_callback = state_callback,
            bitrate        = args.bitrate,
        )

    @cli.register_action(help="Prerender direct renderings from the model", args=[
        ("output_path",dict(type=Path, help="Where to store the output. We recommend a .mp4 suffix.")),
        ("--frames",   dict(type=int, default=180, help="Number of frames. Default is 180")),
        ("--fps",      dict(type=int, default=60, help="Default is 60")),
        ("--shading",  dict(type=int, default=ModelViewer.vizmodes_shading  .index("lambertian"),             help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_shading))}}}")),
        ("--centroid", dict(type=int, default=ModelViewer.vizmodes_centroids.index("best-centroids-colored"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_centroids))}}}")),
        ("--spheres",  dict(type=int, default=ModelViewer.vizmodes_spheres  .index(None),                     help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_spheres))}}}")),
        ("--analytical-normals",  dict(action="store_true")),
        ("--solo-atom",dict(type=int, default=None, help="Rendering mode")),
        ("--res",      dict(type=int, nargs=2, default=(320, 240), help="Rendering resolution. Default is 320 240")),
        ("--bg",       dict(choices=["map", "white", "black"], default="map")),
        ("--skyward",  dict(type=str, default="+Z", help='one of: "+X", "-X", "+Y", "-Y", ["+Z"], "-Z"')),
        ("--bitrate",  dict(type=str, default="1500k", help="Encoding bitrate. Default is 1500k")),
        ("--cam-state",dict(type=str, default=None, help="json cam state, expored with CTRL+H")),
    ])
    @torch.no_grad()
    def render_video_spin(args: Namespace, config: Munch, model: IField, **kw):
        """
        Render a video of the first uid of the model while the camera spins once around it.

        The camera rotation advances by 1/`--frames` of a revolution per frame,
        starting from the initial (or `--cam-state` provided) rotation.
        """
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            model.to("cuda")
        viewer = ModelViewer(model, start_uid=next(iter(model.keys())),
            name           = config.experiment_name,
            screenshot_dir = Path(__file__).parent.parent / "images/pygame-viewer",
            res            = args.res,
            skyward        = args.skyward,
        )
        if args.cam_state is not None:
            viewer.cam_state         = json.loads(args.cam_state)
        viewer.display_mode_shading  = args.shading
        viewer.display_mode_centroid = args.centroid
        viewer.display_mode_spheres  = args.spheres
        if args.analytical_normals: viewer.display_mode_normals = viewer.vizmodes_normals.index("analytical")
        viewer.atom_index_solo       = args.solo_atom
        viewer.display_sphere_map_bg = { "map": True, "white": 255, "black": 0 }[args.bg]
        # read after the cam state was applied, such that the spin starts from it
        cam_rot_x_init = viewer.cam_rot_x
        def state_callback(self: ModelViewer, frame: int):
            # Use the exact value of pi: the previous 3.14 approximation made the
            # spin fall slightly short of a full revolution.
            self.cam_rot_x = cam_rot_x_init + 2 * np.pi * (frame / args.frames)
        print(f"Writing video to {str(args.output_path)!r}...")
        viewer.render_headless(args.output_path,
            n_frames       = args.frames,
            fps            = args.fps,
            state_callback = state_callback,
            bitrate        = args.bitrate,
        )

    @cli.register_action(help="Compute reconstruction metrics for the model and write them as json", args=[
        ("fname",             dict(type=Path, help="where to write json")),
        ("-t", "--transpose", dict(action="store_true", help="transpose the output")),
        ("--single-shape",    dict(action="store_true", help="break after first shape")),
        ("--batch-size",      dict(type=int, default=40_000, help="tradeoff between vram usage and efficiency")),
        ("--n-cd",            dict(type=int, default=30_000, help="Number of points to use when computing chamfer distance")),
        ("--filter-outliers", dict(action="store_true", help="like in PRIF")),
    ])
    @torch.enable_grad()
    def compute_scores(args: Namespace, config: Munch, model: IField, **kw):
        """
        Evaluate the model against the ground truth sphere scan of every uid.

        Per uid, the rays of the scan are fed through the model in batches of
        `--batch-size`. The hit/miss confusion counts, squared errors and
        normal cosine distances are accumulated per ray, and a chamfer
        distance is computed on `--n-cd` randomly drawn points. A summary is
        printed, and the per-uid metrics are written as json to `fname`, or
        printed to stdout if `fname` is "-".

        Gradients are enabled since the analytical normals are derived from
        the jacobian of the intersections with respect to the ray origins.
        """
        datamodule_cls: RayFieldAdDataModuleBase = cli.get_datamodule_cls_from_config(args, config)
        model.eval()
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            model.to("cuda")

        def T(array: np.ndarray, **kw) -> torch.Tensor:
            # Move an array onto the device of the model. Tensors are passed through as-is.
            if isinstance(array, torch.Tensor): return array
            # Floating point data is cast to the dtype of the model, everything else
            # (bool masks, integers) keeps its dtype. The dtype of the array has to be
            # inspected for this: an ndarray is never an instance of `np.floating`.
            is_floating = np.issubdtype(np.asarray(array).dtype, np.floating)
            return torch.tensor(array, device=model.device, dtype=model.dtype if is_floating else None, **kw)

        MEDIAL = model.hparams.output_mode == "medial_sphere"
        if not MEDIAL: assert model.hparams.output_mode == "orthogonal_plane"


        uids = sorted(model.keys())
        if args.single_shape: uids = [uids[0]]
        rich.print(f"{datamodule_cls.__name__            = }")
        rich.print(f"{len(uids)                          = }")

        # accumulators for IoU and F-Score, CD and COS

        # sum reduction:
        n            = defaultdict(int)
        n_gt_hits    = defaultdict(int)
        n_gt_miss    = defaultdict(int)
        n_gt_missing = defaultdict(int)
        n_outliers   = defaultdict(int)
        p_mse        = defaultdict(int)
        s_mse        = defaultdict(int)
        cossim_med   = defaultdict(int) # medial normals
        cossim_jac   = defaultdict(int) # jacobian normals
        TP,FN,FP,TN  = [defaultdict(int) for _ in range(4)] # IoU and f-score
        # mean reduction:
        cd_dist    = {} # chamfer distance
        cd_cos_med = {} # chamfer medial normals
        cd_cos_jac = {} # chamfer jacobian normals
        # the tables written to the json output, the medial-only and outlier tables are added on demand
        all_metrics = dict(
            n=n, n_gt_hits=n_gt_hits, n_gt_miss=n_gt_miss, n_gt_missing=n_gt_missing, p_mse=p_mse,
            cossim_jac=cossim_jac,
            TP=TP, FN=FN, FP=FP, TN=TN, cd_dist=cd_dist,
            cd_cos_jac=cd_cos_jac,
        )
        if MEDIAL:
            all_metrics["s_mse"]      = s_mse
            all_metrics["cossim_med"] = cossim_med
            all_metrics["cd_cos_med"] = cd_cos_med
        if args.filter_outliers:
            all_metrics["n_outliers"] = n_outliers

        t = datetime.now()
        for uid in tqdm(uids, desc="Dataset", position=0, leave=True, disable=len(uids)<=1):
            sphere_scan_gt = datamodule_cls.get_sphere_scan_from_uid(uid)

            z      = model[uid].detach()

            # predictions of the rays predicted to intersect, collected over all batches for the chamfer distance
            all_intersections    = []
            all_medial_normals   = []
            all_jacobian_normals = []

            step = args.batch_size
            for i in tqdm(range(0, sphere_scan_gt.hits.shape[0], step), desc=f"Item {uid!r}", position=1, leave=False):
                # prepare batch and gt
                origins      = T(sphere_scan_gt.cam_pos  [i:i+step, :], requires_grad = True)
                dirs         = T(sphere_scan_gt.ray_dirs [i:i+step, :])
                gt_hits      = T(sphere_scan_gt.hits     [i:i+step])
                gt_miss      = T(sphere_scan_gt.miss     [i:i+step])
                gt_missing   = T(sphere_scan_gt.missing  [i:i+step])
                gt_points    = T(sphere_scan_gt.points   [i:i+step, :])
                gt_normals   = T(sphere_scan_gt.normals  [i:i+step, :])
                gt_distances = T(sphere_scan_gt.distances[i:i+step])

                # forward
                if MEDIAL:
                    (
                        depths,
                        silhouettes,
                        intersections,
                        medial_normals,
                        is_intersecting,
                        sphere_centers,
                        sphere_radii,
                    ) = model({
                            "origins" : origins,
                            "dirs"    : dirs,
                        }, z, intersections_only=False, allow_nans=False)
                else:
                    silhouettes = medial_normals = None
                    intersections, is_intersecting = model({
                            "origins" : origins,
                            "dirs"    : dirs,
                        }, z, normalize_origins = True)
                    is_intersecting = is_intersecting > 0.5
                jac = diff.jacobian(intersections, origins, detach=True)

                # outlier removal (PRIF)
                if args.filter_outliers:
                    outliers = jac.norm(dim=-2).norm(dim=-1) > 5
                    # only outliers among the rays predicted to intersect are counted
                    n_outliers[uid] += outliers[is_intersecting].sum().item()
                    # We count filtered points as misses
                    is_intersecting &= ~outliers

                model.zero_grad()
                jacobian_normals = model.compute_normals_from_intersection_origin_jacobian(jac, dirs)

                all_intersections   .append(intersections   .detach()[is_intersecting.detach(), :])
                all_medial_normals  .append(medial_normals  .detach()[is_intersecting.detach(), :]) if MEDIAL else None
                all_jacobian_normals.append(jacobian_normals.detach()[is_intersecting.detach(), :])

                # accumulate metrics
                with torch.no_grad():
                    n                    [uid] += dirs.shape[0]
                    n_gt_hits            [uid] += gt_hits.sum().item()
                    n_gt_miss            [uid] += gt_miss.sum().item()
                    n_gt_missing         [uid] += gt_missing.sum().item()
                    p_mse                [uid] += (gt_points   [gt_hits, :] - intersections[gt_hits, :]).norm(2, dim=-1).pow(2).sum().item()
                    if MEDIAL: s_mse     [uid] += (gt_distances[gt_miss]    - silhouettes  [gt_miss]   )                .pow(2).sum().item()
                    if MEDIAL: cossim_med[uid] += (1-F.cosine_similarity(gt_normals[gt_hits, :], medial_normals  [gt_hits, :], dim=-1).abs()).sum().item() # to match what pytorch3d does for CD
                    cossim_jac           [uid] += (1-F.cosine_similarity(gt_normals[gt_hits, :], jacobian_normals[gt_hits, :], dim=-1).abs()).sum().item() # to match what pytorch3d does for CD
                    # "missing" gt rays are counted as positives, together with the hits
                    not_intersecting = ~is_intersecting
                    TP                   [uid] += ((gt_hits | gt_missing) &  is_intersecting).sum().item() # True  Positive
                    FN                   [uid] += ((gt_hits | gt_missing) & not_intersecting).sum().item() # False Negative
                    FP                   [uid] += (gt_miss                &  is_intersecting).sum().item() # False Positive
                    TN                   [uid] += (gt_miss                & not_intersecting).sum().item() # True  Negative

            all_intersections    = torch.cat(all_intersections,    dim=0)
            all_medial_normals   = torch.cat(all_medial_normals,   dim=0) if MEDIAL else None
            all_jacobian_normals = torch.cat(all_jacobian_normals, dim=0)

            hits = sphere_scan_gt.hits # brevity
            print()

            # draw the random subsets of predicted and gt points to compute the chamfer distance on
            # NOTE(review): only the number of predicted points is checked against n_cd, not the number of gt hits
            assert all_intersections.shape[0] >= args.n_cd
            idx_cd_pred  = torch.randperm(all_intersections.shape[0])[:args.n_cd]
            idx_cd_gt    = torch.randperm(hits.sum())                [:args.n_cd]

            print("cd... ", end="")
            tt = datetime.now()
            loss_cd, loss_cos_jac  = chamfer_distance(
                x         = all_intersections       [None, :,    :][:, idx_cd_pred, :].detach(),
                x_normals = all_jacobian_normals    [None, :,    :][:, idx_cd_pred, :].detach(),
                y         = T(sphere_scan_gt.points [None, hits, :][:, idx_cd_gt,   :]),
                y_normals = T(sphere_scan_gt.normals[None, hits, :][:, idx_cd_gt,   :]),
                batch_reduction = "sum", point_reduction = "sum",
            )
            # same points, but with the medial normals instead. Only the normal loss is kept
            if MEDIAL: _, loss_cos_med = chamfer_distance(
                x         = all_intersections       [None, :,    :][:, idx_cd_pred, :].detach(),
                x_normals = all_medial_normals      [None, :,    :][:, idx_cd_pred, :].detach(),
                y         = T(sphere_scan_gt.points [None, hits, :][:, idx_cd_gt,   :]),
                y_normals = T(sphere_scan_gt.normals[None, hits, :][:, idx_cd_gt,   :]),
                batch_reduction = "sum", point_reduction = "sum",
            )
            print(datetime.now() - tt)

            cd_dist    [uid] = loss_cd.item()
            cd_cos_med [uid] = loss_cos_med.item() if MEDIAL else None
            cd_cos_jac [uid] = loss_cos_jac.item()

        print()
        model.zero_grad(set_to_none=True)
        print("Total time:",    datetime.now() - t)
        print("Time per item:", (datetime.now() - t) / len(uids)) if len(uids) > 1 else None

        # Reductions over the values of one or more per-uid tables.
        # The names are printed verbatim by the `=` f-strings below, hence the shadowing of `sum`.
        sum   = lambda *xs: builtins  .sum  (itertools.chain(*(x.values() for x in xs)))
        mean  = lambda *xs: statistics.mean (itertools.chain(*(x.values() for x in xs)))
        stdev = lambda *xs: statistics.stdev(itertools.chain(*(x.values() for x in xs)))
        n_cd  = args.n_cd
        P = sum(TP)/(sum(TP, FP))
        R = sum(TP)/(sum(TP, FN))
        print(f"{mean(n)                            = :11.1f}      (rays per object)")
        print(f"{mean(n_gt_hits)                    = :11.1f}      (gt rays hitting per object)")
        print(f"{mean(n_gt_miss)                    = :11.1f}      (gt rays missing per object)")
        print(f"{mean(n_gt_missing)                 = :11.1f}      (gt rays unknown per object)")
        print(f"{mean(n_outliers)                   = :11.1f}      (pred rays filtered as outliers per object)") if args.filter_outliers else None
        print(f"{n_cd                               = :11.0f}      (cd rays per object)")
        print(f"{mean(n_gt_hits)   / mean(n)        = :11.8f}      (fraction rays hitting per object)")
        print(f"{mean(n_gt_miss)   / mean(n)        = :11.8f}      (fraction rays missing per object)")
        print(f"{mean(n_gt_missing)/ mean(n)        = :11.8f}      (fraction rays unknown per object)")
        print(f"{mean(n_outliers)  / mean(n)        = :11.8f}      (fraction rays filtered as outliers per object)") if args.filter_outliers else None
        print(f"{sum(TP)/sum(n)                     = :11.8f}      (total ray TP)")
        print(f"{sum(TN)/sum(n)                     = :11.8f}      (total ray TN)")
        print(f"{sum(FP)/sum(n)                     = :11.8f}      (total ray FP)")
        print(f"{sum(FN)/sum(n)                     = :11.8f}      (total ray FN)")
        print(f"{sum(TP, FN, FP)/sum(n)             = :11.8f}      (total ray union)")
        print(f"{sum(TP)/sum(TP, FN, FP)            = :11.8f}      (total ray IoU)")
        print(f"{sum(TP)/(sum(TP, FP))              = :11.8f} -> P (total ray precision)")
        print(f"{sum(TP)/(sum(TP, FN))              = :11.8f} -> R (total ray recall)")
        print(f"{2*(P*R)/(P+R)                      = :11.8f}      (total ray F-score)")
        print(f"{sum(p_mse)/sum(n_gt_hits)          = :11.8f}      (mean ray intersection mean squared error)")
        # the silhouette error is only accumulated for medial models, like the other medial-only metrics
        print(f"{sum(s_mse)/sum(n_gt_miss)          = :11.8f}      (mean ray silhoutette  mean squared error)") if MEDIAL else None
        print(f"{sum(cossim_med)/sum(n_gt_hits)     = :11.8f}      (mean ray medial reduced cosine similarity)") if MEDIAL else None
        print(f"{sum(cossim_jac)/sum(n_gt_hits)     = :11.8f}      (mean ray analytical reduced cosine similarity)")
        print(f"{mean(cd_dist)   /n_cd * 1e3        = :11.8f}      (mean chamfer distance)")
        print(f"{mean(cd_cos_med)/n_cd              = :11.8f}      (mean chamfer reduced medial cossim distance)") if MEDIAL else None
        print(f"{mean(cd_cos_jac)/n_cd              = :11.8f}      (mean chamfer reduced analytical cossim distance)")
        print(f"{stdev(cd_dist)   /n_cd * 1e3       = :11.8f}      (stdev chamfer distance)")                           if len(cd_dist) > 1    else None
        print(f"{stdev(cd_cos_med)/n_cd             = :11.8f}      (stdev chamfer reduced medial cossim distance)")     if len(cd_cos_med) > 1 and MEDIAL else None
        print(f"{stdev(cd_cos_jac)/n_cd             = :11.8f}      (stdev chamfer reduced analytical cossim distance)") if len(cd_cos_jac) > 1 else None

        if args.transpose:
            # turn {metric: {uid: value}} into {uid: {metric: value}}
            all_metrics, old_metrics = defaultdict(dict), all_metrics
            for m, table in old_metrics.items():
                for uid, vals in table.items():
                    all_metrics[uid][m] = vals
            all_metrics["_hparams"] = dict(n_cd=args.n_cd)
        else:
            all_metrics["n_cd"]  = args.n_cd

        if str(args.fname) == "-":
            # print to stdout instead, with one top-level entry per line
            print("{", ',\n'.join(
                f"  {json.dumps(k)}: {json.dumps(v)}"
                for k, v in all_metrics.items()
            ), "}", sep="\n")
        else:
            args.fname.parent.mkdir(parents=True, exist_ok=True)
            with args.fname.open("w") as f:
                json.dump(all_metrics, f, indent=2)

    return cli


if __name__ == "__main__":
    mk_cli().run()