This commit is contained in:
2023-07-19 19:29:10 +02:00
parent b2a64395bd
commit 4f811cc4b0
60 changed files with 18209 additions and 1 deletion

624
experiments/marf.py Executable file
View File

@@ -0,0 +1,624 @@
#!/usr/bin/env python3
from abc import ABC, abstractmethod
from argparse import Namespace
from collections import defaultdict
from datetime import datetime
from ifield import logging
from ifield.cli import CliInterface
from ifield.data.common.scan import SingleViewUVScan
from ifield.data.coseg import read as coseg_read
from ifield.data.stanford import read as stanford_read
from ifield.datasets import stanford, coseg, common
from ifield.models import intersection_fields
from ifield.utils.operators import diff
from ifield.viewer.ray_field import ModelViewer
from munch import Munch
from pathlib import Path
from pytorch3d.loss.chamfer import chamfer_distance
from pytorch_lightning.utilities import rank_zero_only
from torch.nn import functional as F
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
from trimesh import Trimesh
from typing import Union
import builtins
import itertools
import json
import numpy as np
import pytorch_lightning as pl
import rich
import rich.pretty
import statistics
import torch
pl.seed_everything(31337)
torch.set_float32_matmul_precision('medium')
IField = intersection_fields.IntersectionFieldAutoDecoderModel # brevity
class RayFieldAdDataModuleBase(pl.LightningDataModule, ABC):
@property
@abstractmethod
def observation_ids(self) -> list[str]:
...
@abstractmethod
def mk_ad_dataset(self) -> common.AutodecoderDataset:
...
@staticmethod
@abstractmethod
def get_trimesh_from_uid(uid) -> Trimesh:
...
@staticmethod
@abstractmethod
def get_sphere_scan_from_uid(uid) -> SingleViewUVScan:
...
def setup(self, stage=None):
assert stage in ["fit", None] # fit is for train/val, None is for all. "test" not supported ATM
if not self.hparams.data_dir is None:
coseg.config.DATA_PATH = self.hparams.data_dir
step = self.hparams.step # brevity
dataset = self.mk_ad_dataset()
n_items_pre_step_mapping = len(dataset)
if step > 1:
dataset = common.TransformExtendedDataset(dataset)
for sx in range(step):
for sy in range(step):
def make_slicer(sx, sy, step) -> callable: # the closure is required
if step > 1:
return lambda t: t[sx::step, sy::step]
else:
return lambda t: t
@dataset.map(slicer=make_slicer(sx, sy, step))
def unpack(sample: tuple[str, SingleViewUVScan], slicer: callable):
scan: SingleViewUVScan = sample[1]
assert not scan.hits.shape[0] % step, f"{scan.hits.shape[0]=} not divisible by {step=}"
assert not scan.hits.shape[1] % step, f"{scan.hits.shape[1]=} not divisible by {step=}"
return {
"z_uid" : sample[0],
"origins" : scan.cam_pos,
"dirs" : slicer(scan.ray_dirs),
"points" : slicer(scan.points),
"hits" : slicer(scan.hits),
"miss" : slicer(scan.miss),
"normals" : slicer(scan.normals),
"distances" : slicer(scan.distances),
}
# Split each object into train/val with SampleSplit
n_items = len(dataset)
n_val = int(n_items * self.hparams.val_fraction)
n_train = n_items - n_val
self.generator = torch.Generator().manual_seed(self.hparams.prng_seed)
# split the dataset such that all steps are in same part
assert n_items == n_items_pre_step_mapping * step * step, (n_items, n_items_pre_step_mapping, step)
indices = [
i*step*step + sx*step + sy
for i in torch.randperm(n_items_pre_step_mapping, generator=self.generator).tolist()
for sx in range(step)
for sy in range(step)
]
self.dataset_train = Subset(dataset, sorted(indices[:n_train], key=lambda x: torch.rand(1, generator=self.generator).tolist()[0]))
self.dataset_val = Subset(dataset, sorted(indices[n_train:n_train+n_val], key=lambda x: torch.rand(1, generator=self.generator).tolist()[0]))
assert len(self.dataset_train) % self.hparams.batch_size == 0
assert len(self.dataset_val) % self.hparams.batch_size == 0
def train_dataloader(self):
return DataLoader(self.dataset_train,
batch_size = self.hparams.batch_size,
drop_last = self.hparams.drop_last,
num_workers = self.hparams.num_workers,
persistent_workers = self.hparams.persistent_workers,
pin_memory = self.hparams.pin_memory,
prefetch_factor = self.hparams.prefetch_factor,
shuffle = self.hparams.shuffle,
generator = self.generator,
)
def val_dataloader(self):
return DataLoader(self.dataset_val,
batch_size = self.hparams.batch_size,
drop_last = self.hparams.drop_last,
num_workers = self.hparams.num_workers,
persistent_workers = self.hparams.persistent_workers,
pin_memory = self.hparams.pin_memory,
prefetch_factor = self.hparams.prefetch_factor,
generator = self.generator,
)
class StanfordUVDataModule(RayFieldAdDataModuleBase):
    """Datamodule for single-view UV scans of the Stanford 3D scan models.

    All constructor arguments are snapshotted into ``self.hparams`` via
    ``save_hyperparameters`` and consumed by the base class.
    """
    skyward = "+Z"  # up-axis hint for the viewer

    def __init__(self,
            data_dir           : Union[str, Path, None] = None,
            obj_names          : list[str] = ["bunny"], # empty means all. NOTE: mutable default is safe here — never mutated, only replaced
            prng_seed          : int   = 1337,
            step               : int   = 2,
            batch_size         : int   = 5,
            drop_last          : bool  = False,
            num_workers        : int   = 8,
            persistent_workers : bool  = True,
            pin_memory         : bool  = True,  # fixed: was annotated `int`, value is a bool
            prefetch_factor    : int   = 2,
            shuffle            : bool  = True,
            val_fraction       : float = 0.30,
            ):
        super().__init__()
        if not obj_names:
            # empty list means "use every available object"
            obj_names = stanford_read.list_object_names()
        self.save_hyperparameters()  # snapshots the (possibly expanded) locals into self.hparams

    @property
    def observation_ids(self) -> list[str]:
        """Object names double as observation/latent-code ids."""
        return self.hparams.obj_names

    def mk_ad_dataset(self) -> common.AutodecoderDataset:
        """Build the Stanford (uid, SingleViewUVScan) autodecoder dataset."""
        return stanford.AutodecoderSingleViewUVScanDataset(
            obj_names = self.hparams.obj_names,
            data_path = self.hparams.data_dir,
        )

    @staticmethod
    def get_trimesh_from_uid(obj_name) -> Trimesh:
        """Load the ground-truth mesh, rescaled to the unit sphere."""
        import mesh_to_sdf  # local import: heavy optional dependency
        mesh = stanford_read.read_mesh(obj_name)
        return mesh_to_sdf.scale_to_unit_sphere(mesh)

    @staticmethod
    def get_sphere_scan_from_uid(obj_name) -> SingleViewUVScan:
        """Load the ground-truth sphere scan used for metric computation."""
        return stanford_read.read_mesh_mesh_sphere_scan(obj_name)
class CosegUVDataModule(RayFieldAdDataModuleBase):
    """Datamodule for single-view UV scans of the COSEG shape sets.

    All constructor arguments are snapshotted into ``self.hparams`` via
    ``save_hyperparameters`` and consumed by the base class.
    """
    skyward = "+Y"  # up-axis hint for the viewer

    def __init__(self,
            data_dir           : Union[str, Path, None] = None,
            # fixed annotation: was `tuple[str]` (a 1-tuple) despite the list
            # default; lists are accepted and converted to a tuple below.
            object_sets        : Union[list[str], tuple[str, ...]] = ["tele-aliens"], # empty means all
            prng_seed          : int   = 1337,
            step               : int   = 2,
            batch_size         : int   = 5,
            drop_last          : bool  = False,
            num_workers        : int   = 8,
            persistent_workers : bool  = True,
            pin_memory         : bool  = True,  # fixed: was annotated `int`, value is a bool
            prefetch_factor    : int   = 2,
            shuffle            : bool  = True,
            val_fraction       : float = 0.30,
            ):
        super().__init__()
        if not object_sets:
            # empty means "use every available object set"
            object_sets = coseg_read.list_object_sets()
        object_sets = tuple(object_sets)  # hashable + immutable for hparams
        self.save_hyperparameters()

    @property
    def observation_ids(self) -> list[str]:
        """Per-model id strings across the selected object sets."""
        return coseg_read.list_model_id_strings(self.hparams.object_sets)

    def mk_ad_dataset(self) -> common.AutodecoderDataset:
        """Build the COSEG (uid, SingleViewUVScan) autodecoder dataset."""
        return coseg.AutodecoderSingleViewUVScanDataset(
            object_sets = self.hparams.object_sets,
            data_path   = self.hparams.data_dir,
        )

    @staticmethod
    def get_trimesh_from_uid(string_uid):
        """Ground-truth mesh loading is not available for COSEG."""
        raise NotImplementedError

    @staticmethod
    def get_sphere_scan_from_uid(string_uid) -> SingleViewUVScan:
        """Load the ground-truth sphere scan for a COSEG model id string."""
        uid = coseg_read.model_id_string_to_uid(string_uid)
        return coseg_read.read_mesh_mesh_sphere_scan(*uid)
def mk_cli(args=None) -> CliInterface:
cli = CliInterface(
module_cls = IField,
datamodule_cls = [StanfordUVDataModule, CosegUVDataModule],
workdir = Path(__file__).parent.resolve(),
experiment_name_prefix = "ifield",
)
cli.trainer_defaults.update(dict(
precision = 16,
min_epochs = 5,
))
@cli.register_pre_training_callback
def populate_autodecoder_z_uids(args: Namespace, config: Munch, module: IField, trainer: pl.Trainer, datamodule: RayFieldAdDataModuleBase, logger: logging.Logger):
module.set_observation_ids(datamodule.observation_ids)
rank = getattr(rank_zero_only, "rank", 0)
rich.print(f"[rank {rank}] {len(datamodule.observation_ids) = }")
rich.print(f"[rank {rank}] {len(datamodule.observation_ids) > 1 = }")
rich.print(f"[rank {rank}] {module.is_conditioned = }")
@cli.register_action(help="Interactive window with direct renderings from the model", args=[
("--shading", dict(type=int, default=ModelViewer.vizmodes_shading .index("lambertian"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_shading))}}}")),
("--centroid", dict(type=int, default=ModelViewer.vizmodes_centroids.index("best-centroids-colored"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_centroids))}}}")),
("--spheres", dict(type=int, default=ModelViewer.vizmodes_spheres .index(None), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_spheres))}}}")),
("--analytical-normals", dict(action="store_true")),
("--ground-truth", dict(action="store_true")),
("--solo-atom",dict(type=int, default=None, help="Rendering mode")),
("--res", dict(type=int, nargs=2, default=(210, 160), help="Rendering resolution")),
("--bg", dict(choices=["map", "white", "black"], default="map")),
("--skyward", dict(type=str, default="+Z", help='one of: "+X", "-X", "+Y", "-Y", ["+Z"], "-Z"')),
("--scale", dict(type=int, default=3, help="Rendering scale")),
("--fps", dict(type=int, default=None, help="FPS upper limit")),
("--cam-state",dict(type=str, default=None, help="json cam state, expored with CTRL+H")),
("--write", dict(type=Path, default=None, help="Where to write a screenshot.")),
])
@torch.no_grad()
def viewer(args: Namespace, config: Munch, model: IField):
datamodule_cls: RayFieldAdDataModuleBase = cli.get_datamodule_cls_from_config(args, config)
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
model.to("cuda")
viewer = ModelViewer(model, start_uid=next(iter(model.keys())),
name = config.experiment_name,
screenshot_dir = Path(__file__).parent.parent / "images/pygame-viewer",
res = args.res,
skyward = args.skyward,
scale = args.scale,
mesh_gt_getter = datamodule_cls.get_trimesh_from_uid,
)
viewer.display_mode_shading = args.shading
viewer.display_mode_centroid = args.centroid
viewer.display_mode_spheres = args.spheres
if args.ground_truth: viewer.display_mode_normals = viewer.vizmodes_normals.index("ground_truth")
if args.analytical_normals: viewer.display_mode_normals = viewer.vizmodes_normals.index("analytical")
viewer.atom_index_solo = args.solo_atom
viewer.fps_cap = args.fps
viewer.display_sphere_map_bg = { "map": True, "white": 255, "black": 0 }[args.bg]
if args.cam_state is not None:
viewer.cam_state = json.loads(args.cam_state)
if args.write is None:
viewer.run()
else:
assert args.write.suffix == ".png", args.write.name
viewer.render_headless(args.write,
n_frames = 1,
fps = 1,
state_callback = None,
)
@cli.register_action(help="Prerender direct renderings from the model", args=[
("output_path",dict(type=Path, help="Where to store the output. We recommend a .mp4 suffix.")),
("uids", dict(type=str, nargs="*")),
("--frames", dict(type=int, default=60, help="Number of per interpolation. Default is 60")),
("--fps", dict(type=int, default=60, help="Default is 60")),
("--shading", dict(type=int, default=ModelViewer.vizmodes_shading .index("lambertian"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_shading))}}}")),
("--centroid", dict(type=int, default=ModelViewer.vizmodes_centroids.index("best-centroids-colored"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_centroids))}}}")),
("--spheres", dict(type=int, default=ModelViewer.vizmodes_spheres .index(None), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_spheres))}}}")),
("--analytical-normals", dict(action="store_true")),
("--solo-atom",dict(type=int, default=None, help="Rendering mode")),
("--res", dict(type=int, nargs=2, default=(240, 240), help="Rendering resolution. Default is 240 240")),
("--bg", dict(choices=["map", "white", "black"], default="map")),
("--skyward", dict(type=str, default="+Z", help='one of: "+X", "-X", "+Y", "-Y", ["+Z"], "-Z"')),
("--bitrate", dict(type=str, default="1500k", help="Encoding bitrate. Default is 1500k")),
("--cam-state",dict(type=str, default=None, help="json cam state, expored with CTRL+H")),
])
@torch.no_grad()
def render_video_interpolation(args: Namespace, config: Munch, model: IField, **kw):
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
model.to("cuda")
uids = args.uids or list(model.keys())
assert len(uids) > 1
if not args.uids: uids.append(uids[0])
viewer = ModelViewer(model, uids[0],
name = config.experiment_name,
screenshot_dir = Path(__file__).parent.parent / "images/pygame-viewer",
res = args.res,
skyward = args.skyward,
)
if args.cam_state is not None:
viewer.cam_state = json.loads(args.cam_state)
viewer.display_mode_shading = args.shading
viewer.display_mode_centroid = args.centroid
viewer.display_mode_spheres = args.spheres
if args.analytical_normals: viewer.display_mode_normals = viewer.vizmodes_normals.index("analytical")
viewer.atom_index_solo = args.solo_atom
viewer.display_sphere_map_bg = { "map": True, "white": 255, "black": 0 }[args.bg]
def state_callback(self: ModelViewer, frame: int):
if frame % args.frames:
self.lambertian_color = (0.8, 0.8, 1.0)
else:
self.lambertian_color = (1.0, 1.0, 1.0)
self.fps = args.frames
idx = frame // args.frames + 1
if idx != len(uids):
self.current_uid = uids[idx]
print(f"Writing video to {str(args.output_path)!r}...")
viewer.render_headless(args.output_path,
n_frames = args.frames * (len(uids)-1) + 1,
fps = args.fps,
state_callback = state_callback,
bitrate = args.bitrate,
)
@cli.register_action(help="Prerender direct renderings from the model", args=[
("output_path",dict(type=Path, help="Where to store the output. We recommend a .mp4 suffix.")),
("--frames", dict(type=int, default=180, help="Number of frames. Default is 180")),
("--fps", dict(type=int, default=60, help="Default is 60")),
("--shading", dict(type=int, default=ModelViewer.vizmodes_shading .index("lambertian"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_shading))}}}")),
("--centroid", dict(type=int, default=ModelViewer.vizmodes_centroids.index("best-centroids-colored"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_centroids))}}}")),
("--spheres", dict(type=int, default=ModelViewer.vizmodes_spheres .index(None), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_spheres))}}}")),
("--analytical-normals", dict(action="store_true")),
("--solo-atom",dict(type=int, default=None, help="Rendering mode")),
("--res", dict(type=int, nargs=2, default=(320, 240), help="Rendering resolution. Default is 320 240")),
("--bg", dict(choices=["map", "white", "black"], default="map")),
("--skyward", dict(type=str, default="+Z", help='one of: "+X", "-X", "+Y", "-Y", ["+Z"], "-Z"')),
("--bitrate", dict(type=str, default="1500k", help="Encoding bitrate. Default is 1500k")),
("--cam-state",dict(type=str, default=None, help="json cam state, expored with CTRL+H")),
])
@torch.no_grad()
def render_video_spin(args: Namespace, config: Munch, model: IField, **kw):
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
model.to("cuda")
viewer = ModelViewer(model, start_uid=next(iter(model.keys())),
name = config.experiment_name,
screenshot_dir = Path(__file__).parent.parent / "images/pygame-viewer",
res = args.res,
skyward = args.skyward,
)
if args.cam_state is not None:
viewer.cam_state = json.loads(args.cam_state)
viewer.display_mode_shading = args.shading
viewer.display_mode_centroid = args.centroid
viewer.display_mode_spheres = args.spheres
if args.analytical_normals: viewer.display_mode_normals = viewer.vizmodes_normals.index("analytical")
viewer.atom_index_solo = args.solo_atom
viewer.display_sphere_map_bg = { "map": True, "white": 255, "black": 0 }[args.bg]
cam_rot_x_init = viewer.cam_rot_x
def state_callback(self: ModelViewer, frame: int):
self.cam_rot_x = cam_rot_x_init + 3.14 * (frame / args.frames) * 2
print(f"Writing video to {str(args.output_path)!r}...")
viewer.render_headless(args.output_path,
n_frames = args.frames,
fps = args.fps,
state_callback = state_callback,
bitrate = args.bitrate,
)
@cli.register_action(help="foo", args=[
("fname", dict(type=Path, help="where to write json")),
("-t", "--transpose", dict(action="store_true", help="transpose the output")),
("--single-shape", dict(action="store_true", help="break after first shape")),
("--batch-size", dict(type=int, default=40_000, help="tradeoff between vram usage and efficiency")),
("--n-cd", dict(type=int, default=30_000, help="Number of points to use when computing chamfer distance")),
("--filter-outliers", dict(action="store_true", help="like in PRIF")),
])
@torch.enable_grad()
def compute_scores(args: Namespace, config: Munch, model: IField, **kw):
datamodule_cls: RayFieldAdDataModuleBase = cli.get_datamodule_cls_from_config(args, config)
model.eval()
if torch.cuda.is_available() and torch.cuda.device_count() > 0:
model.to("cuda")
def T(array: np.ndarray, **kw) -> torch.Tensor:
if isinstance(array, torch.Tensor): return array
return torch.tensor(array, device=model.device, dtype=model.dtype if isinstance(array, np.floating) else None, **kw)
MEDIAL = model.hparams.output_mode == "medial_sphere"
if not MEDIAL: assert model.hparams.output_mode == "orthogonal_plane"
uids = sorted(model.keys())
if args.single_shape: uids = [uids[0]]
rich.print(f"{datamodule_cls.__name__ = }")
rich.print(f"{len(uids) = }")
# accumulators for IoU and F-Score, CD and COS
# sum reduction:
n = defaultdict(int)
n_gt_hits = defaultdict(int)
n_gt_miss = defaultdict(int)
n_gt_missing = defaultdict(int)
n_outliers = defaultdict(int)
p_mse = defaultdict(int)
s_mse = defaultdict(int)
cossim_med = defaultdict(int) # medial normals
cossim_jac = defaultdict(int) # jacovian normals
TP,FN,FP,TN = [defaultdict(int) for _ in range(4)] # IoU and f-score
# mean reduction:
cd_dist = {} # chamfer distance
cd_cos_med = {} # chamfer medial normals
cd_cos_jac = {} # chamfer jacovian normals
all_metrics = dict(
n=n, n_gt_hits=n_gt_hits, n_gt_miss=n_gt_miss, n_gt_missing=n_gt_missing, p_mse=p_mse,
cossim_jac=cossim_jac,
TP=TP, FN=FN, FP=FP, TN=TN, cd_dist=cd_dist,
cd_cos_jac=cd_cos_jac,
)
if MEDIAL:
all_metrics["s_mse"] = s_mse
all_metrics["cossim_med"] = cossim_med
all_metrics["cd_cos_med"] = cd_cos_med
if args.filter_outliers:
all_metrics["n_outliers"] = n_outliers
t = datetime.now()
for uid in tqdm(uids, desc="Dataset", position=0, leave=True, disable=len(uids)<=1):
sphere_scan_gt = datamodule_cls.get_sphere_scan_from_uid(uid)
z = model[uid].detach()
all_intersections = []
all_medial_normals = []
all_jacobian_normals = []
step = args.batch_size
for i in tqdm(range(0, sphere_scan_gt.hits.shape[0], step), desc=f"Item {uid!r}", position=1, leave=False):
# prepare batch and gt
origins = T(sphere_scan_gt.cam_pos [i:i+step, :], requires_grad = True)
dirs = T(sphere_scan_gt.ray_dirs [i:i+step, :])
gt_hits = T(sphere_scan_gt.hits [i:i+step])
gt_miss = T(sphere_scan_gt.miss [i:i+step])
gt_missing = T(sphere_scan_gt.missing [i:i+step])
gt_points = T(sphere_scan_gt.points [i:i+step, :])
gt_normals = T(sphere_scan_gt.normals [i:i+step, :])
gt_distances = T(sphere_scan_gt.distances[i:i+step])
# forward
if MEDIAL:
(
depths,
silhouettes,
intersections,
medial_normals,
is_intersecting,
sphere_centers,
sphere_radii,
) = model({
"origins" : origins,
"dirs" : dirs,
}, z, intersections_only=False, allow_nans=False)
else:
silhouettes = medial_normals = None
intersections, is_intersecting = model({
"origins" : origins,
"dirs" : dirs,
}, z, normalize_origins = True)
is_intersecting = is_intersecting > 0.5
jac = diff.jacobian(intersections, origins, detach=True)
# outlier removal (PRIF)
if args.filter_outliers:
outliers = jac.norm(dim=-2).norm(dim=-1) > 5
n_outliers[uid] += outliers[is_intersecting].sum().item()
# We count filtered points as misses
is_intersecting &= ~outliers
model.zero_grad()
jacobian_normals = model.compute_normals_from_intersection_origin_jacobian(jac, dirs)
all_intersections .append(intersections .detach()[is_intersecting.detach(), :])
all_medial_normals .append(medial_normals .detach()[is_intersecting.detach(), :]) if MEDIAL else None
all_jacobian_normals.append(jacobian_normals.detach()[is_intersecting.detach(), :])
# accumulate metrics
with torch.no_grad():
n [uid] += dirs.shape[0]
n_gt_hits [uid] += gt_hits.sum().item()
n_gt_miss [uid] += gt_miss.sum().item()
n_gt_missing [uid] += gt_missing.sum().item()
p_mse [uid] += (gt_points [gt_hits, :] - intersections[gt_hits, :]).norm(2, dim=-1).pow(2).sum().item()
if MEDIAL: s_mse [uid] += (gt_distances[gt_miss] - silhouettes [gt_miss] ) .pow(2).sum().item()
if MEDIAL: cossim_med[uid] += (1-F.cosine_similarity(gt_normals[gt_hits, :], medial_normals [gt_hits, :], dim=-1).abs()).sum().item() # to match what pytorch3d does for CD
cossim_jac [uid] += (1-F.cosine_similarity(gt_normals[gt_hits, :], jacobian_normals[gt_hits, :], dim=-1).abs()).sum().item() # to match what pytorch3d does for CD
not_intersecting = ~is_intersecting
TP [uid] += ((gt_hits | gt_missing) & is_intersecting).sum().item() # True Positive
FN [uid] += ((gt_hits | gt_missing) & not_intersecting).sum().item() # False Negative
FP [uid] += (gt_miss & is_intersecting).sum().item() # False Positive
TN [uid] += (gt_miss & not_intersecting).sum().item() # True Negative
all_intersections = torch.cat(all_intersections, dim=0)
all_medial_normals = torch.cat(all_medial_normals, dim=0) if MEDIAL else None
all_jacobian_normals = torch.cat(all_jacobian_normals, dim=0)
hits = sphere_scan_gt.hits # brevity
print()
assert all_intersections.shape[0] >= args.n_cd
idx_cd_pred = torch.randperm(all_intersections.shape[0])[:args.n_cd]
idx_cd_gt = torch.randperm(hits.sum()) [:args.n_cd]
print("cd... ", end="")
tt = datetime.now()
loss_cd, loss_cos_jac = chamfer_distance(
x = all_intersections [None, :, :][:, idx_cd_pred, :].detach(),
x_normals = all_jacobian_normals [None, :, :][:, idx_cd_pred, :].detach(),
y = T(sphere_scan_gt.points [None, hits, :][:, idx_cd_gt, :]),
y_normals = T(sphere_scan_gt.normals[None, hits, :][:, idx_cd_gt, :]),
batch_reduction = "sum", point_reduction = "sum",
)
if MEDIAL: _, loss_cos_med = chamfer_distance(
x = all_intersections [None, :, :][:, idx_cd_pred, :].detach(),
x_normals = all_medial_normals [None, :, :][:, idx_cd_pred, :].detach(),
y = T(sphere_scan_gt.points [None, hits, :][:, idx_cd_gt, :]),
y_normals = T(sphere_scan_gt.normals[None, hits, :][:, idx_cd_gt, :]),
batch_reduction = "sum", point_reduction = "sum",
)
print(datetime.now() - tt)
cd_dist [uid] = loss_cd.item()
cd_cos_med [uid] = loss_cos_med.item() if MEDIAL else None
cd_cos_jac [uid] = loss_cos_jac.item()
print()
model.zero_grad(set_to_none=True)
print("Total time:", datetime.now() - t)
print("Time per item:", (datetime.now() - t) / len(uids)) if len(uids) > 1 else None
sum = lambda *xs: builtins .sum (itertools.chain(*(x.values() for x in xs)))
mean = lambda *xs: statistics.mean (itertools.chain(*(x.values() for x in xs)))
stdev = lambda *xs: statistics.stdev(itertools.chain(*(x.values() for x in xs)))
n_cd = args.n_cd
P = sum(TP)/(sum(TP, FP))
R = sum(TP)/(sum(TP, FN))
print(f"{mean(n) = :11.1f} (rays per object)")
print(f"{mean(n_gt_hits) = :11.1f} (gt rays hitting per object)")
print(f"{mean(n_gt_miss) = :11.1f} (gt rays missing per object)")
print(f"{mean(n_gt_missing) = :11.1f} (gt rays unknown per object)")
print(f"{mean(n_outliers) = :11.1f} (gt rays unknown per object)") if args.filter_outliers else None
print(f"{n_cd = :11.0f} (cd rays per object)")
print(f"{mean(n_gt_hits) / mean(n) = :11.8f} (fraction rays hitting per object)")
print(f"{mean(n_gt_miss) / mean(n) = :11.8f} (fraction rays missing per object)")
print(f"{mean(n_gt_missing)/ mean(n) = :11.8f} (fraction rays unknown per object)")
print(f"{mean(n_outliers) / mean(n) = :11.8f} (fraction rays unknown per object)") if args.filter_outliers else None
print(f"{sum(TP)/sum(n) = :11.8f} (total ray TP)")
print(f"{sum(TN)/sum(n) = :11.8f} (total ray TN)")
print(f"{sum(FP)/sum(n) = :11.8f} (total ray FP)")
print(f"{sum(FN)/sum(n) = :11.8f} (total ray FN)")
print(f"{sum(TP, FN, FP)/sum(n) = :11.8f} (total ray union)")
print(f"{sum(TP)/sum(TP, FN, FP) = :11.8f} (total ray IoU)")
print(f"{sum(TP)/(sum(TP, FP)) = :11.8f} -> P (total ray precision)")
print(f"{sum(TP)/(sum(TP, FN)) = :11.8f} -> R (total ray recall)")
print(f"{2*(P*R)/(P+R) = :11.8f} (total ray F-score)")
print(f"{sum(p_mse)/sum(n_gt_hits) = :11.8f} (mean ray intersection mean squared error)")
print(f"{sum(s_mse)/sum(n_gt_miss) = :11.8f} (mean ray silhoutette mean squared error)")
print(f"{sum(cossim_med)/sum(n_gt_hits) = :11.8f} (mean ray medial reduced cosine similarity)") if MEDIAL else None
print(f"{sum(cossim_jac)/sum(n_gt_hits) = :11.8f} (mean ray analytical reduced cosine similarity)")
print(f"{mean(cd_dist) /n_cd * 1e3 = :11.8f} (mean chamfer distance)")
print(f"{mean(cd_cos_med)/n_cd = :11.8f} (mean chamfer reduced medial cossim distance)") if MEDIAL else None
print(f"{mean(cd_cos_jac)/n_cd = :11.8f} (mean chamfer reduced analytical cossim distance)")
print(f"{stdev(cd_dist) /n_cd * 1e3 = :11.8f} (stdev chamfer distance)") if len(cd_dist) > 1 else None
print(f"{stdev(cd_cos_med)/n_cd = :11.8f} (stdev chamfer reduced medial cossim distance)") if len(cd_cos_med) > 1 and MEDIAL else None
print(f"{stdev(cd_cos_jac)/n_cd = :11.8f} (stdev chamfer reduced analytical cossim distance)") if len(cd_cos_jac) > 1 else None
if args.transpose:
all_metrics, old_metrics = defaultdict(dict), all_metrics
for m, table in old_metrics.items():
for uid, vals in table.items():
all_metrics[uid][m] = vals
all_metrics["_hparams"] = dict(n_cd=args.n_cd)
else:
all_metrics["n_cd"] = args.n_cd
if str(args.fname) == "-":
print("{", ',\n'.join(
f" {json.dumps(k)}: {json.dumps(v)}"
for k, v in all_metrics.items()
), "}", sep="\n")
else:
args.fname.parent.mkdir(parents=True, exist_ok=True)
with args.fname.open("w") as f:
json.dump(all_metrics, f, indent=2)
return cli
if __name__ == "__main__":
mk_cli().run()

263
experiments/marf.yaml.j2 Executable file
View File

@@ -0,0 +1,263 @@
#!/usr/bin/env -S python ./marf.py module
{% do require_defined("select", select, 0, "$SLURM_ARRAY_TASK_ID") %}{# requires jinja2.ext.do #}
{% do require_defined("mode", mode, "single", "ablation", "multi", strict=true, exchaustive=true) %}{# requires jinja2.ext.do #}
{% set counter = itertools.count(start=0, step=1) %}
{% set do_condition = mode == "multi" %}
{% set do_ablation = mode == "ablation" %}
{% set hp_matrix = namespace() %}{# hyper parameter matrix #}
{% set hp_matrix.input_mode = [
"both",
"perp_foot",
"plucker",
] if do_ablation else [ "both" ] %}
{% set hp_matrix.output_mode = ["medial_sphere", "orthogonal_plane"] %}{##}
{% set hp_matrix.output_mode = ["medial_sphere"] %}{##}
{% set hp_matrix.n_atoms = [16, 1, 4, 8, 32, 64] if do_ablation else [16] %}{##}
{% set hp_matrix.normal_coeff = [0.25, 0] if do_ablation else [0.25] %}{##}
{% set hp_matrix.dataset_item = [objname] if objname is defined else (["armadillo", "bunny", "happy_buddha", "dragon", "lucy"] if not do_condition else ["four-legged"]) %}{##}
{% set hp_matrix.test_val_split_frac = [0.7] %}{##}
{% set hp_matrix.lr_coeff = [5] %}{##}
{% set hp_matrix.warmup_epochs = [1] if not do_condition else [0.1] %}{##}
{% set hp_matrix.improve_miss_grads = [True] %}{##}
{% set hp_matrix.normalize_ray_dirs = [True] %}{##}
{% set hp_matrix.intersection_coeff = [2, 0] if do_ablation else [2] %}{##}
{% set hp_matrix.miss_distance_coeff = [1, 0, 5] if do_ablation else [1] %}{##}
{% set hp_matrix.relative_out = [False] %}{##}
{% set hp_matrix.hidden_features = [512] %}{# like deepsdf and prif #}
{% set hp_matrix.hidden_layers = [8] %}{# like deepsdf, nerf, prif #}
{% set hp_matrix.nonlinearity = ["leaky_relu"] %}{##}
{% set hp_matrix.omega = [30] %}{##}
{% set hp_matrix.normalization = ["layernorm"] %}{##}
{% set hp_matrix.dropout_percent = [1] %}{##}
{% set hp_matrix.sphere_grow_reg_coeff = [500, 0, 5000] if do_ablation else [500] %}{##}
{% set hp_matrix.geom_init = [True, False] if do_ablation else [True] %}{##}
{% set hp_matrix.loss_inscription = [50, 0, 250] if do_ablation else [50] %}{##}
{% set hp_matrix.atom_centroid_norm_std_reg_negexp = [0, None] if do_ablation else [0] %}{##}
{% set hp_matrix.curvature_reg_coeff = [0.2] %}{##}
{% set hp_matrix.multi_view_reg_coeff = [1, 2] if do_ablation else [1] %}{##}
{% set hp_matrix.grad_reg = [ "multi_view", "nogradreg" ] if do_ablation else [ "multi_view" ] %}
{#% for hp in cartesian_hparams(hp_matrix) %}{##}
{% for hp in ablation_hparams(hp_matrix, caartesian_keys=["output_mode", "dataset_item", "nonlinearity", "test_val_split_frac"]) %}
{% if hp.output_mode == "orthogonal_plane"%}
{% if hp.normal_coeff == 0 %}{% set hp.normal_coeff = 0.25 %}
{% elif hp.normal_coeff == 0.25 %}{% set hp.normal_coeff = 0 %}
{% endif %}
{% if hp.grad_reg == "multi_view" %}{% set hp.grad_reg = "nogradreg" %}
{% elif hp.grad_reg == "nogradreg" %}{% set hp.grad_reg = "multi_view" %}
{% endif %}
{% endif %}
{# filter bad/uninteresting hparam combos #}
{% if ( hp.nonlinearity != "sine" and hp.omega != 30 )
or ( hp.nonlinearity == "sine" and hp.normalization in ("layernorm", "layernorm_na") )
or ( hp.multi_view_reg_coeff != 1 and "multi_view" not in hp.grad_reg )
or ( "curvature" not in hp.grad_reg and hp.curvature_reg_coeff != 0.2 )
or ( hp.output_mode == "orthogonal_plane" and hp.input_mode != "both" )
or ( hp.output_mode == "orthogonal_plane" and hp.atom_centroid_norm_std_reg_negexp != 0 )
or ( hp.output_mode == "orthogonal_plane" and hp.n_atoms != 16 )
or ( hp.output_mode == "orthogonal_plane" and hp.sphere_grow_reg_coeff != 500 )
or ( hp.output_mode == "orthogonal_plane" and hp.loss_inscription != 50 )
or ( hp.output_mode == "orthogonal_plane" and hp.miss_distance_coeff != 1 )
or ( hp.output_mode == "orthogonal_plane" and hp.test_val_split_frac != 0.7 )
or ( hp.output_mode == "orthogonal_plane" and hp.lr_coeff != 5 )
or ( hp.output_mode == "orthogonal_plane" and not hp.geom_init )
or ( hp.output_mode == "orthogonal_plane" and not hp.intersection_coeff )
%}
{% continue %}{# requires jinja2.ext.loopcontrols #}
{% endif %}
{% set index = next(counter) %}
{% if select is not defined and index > 0 %}---{% endif %}
{% if select is not defined or int(select) == index %}
trainer:
gradient_clip_val : 1.0
max_epochs : 200
min_epochs : 200
log_every_n_steps : 20
{% if not do_condition %}
StanfordUVDataModule:
obj_names : ["{{ hp.dataset_item }}"]
step : 4
batch_size : 8
val_fraction : {{ 1-hp.test_val_split_frac }}
{% else %}{# if do_condition #}
CosegUVDataModule:
object_sets : ["{{ hp.dataset_item }}"]
step : 4
batch_size : 8
val_fraction : {{ 1-hp.test_val_split_frac }}
{% endif %}{# if do_condition #}
logging:
save_dir : logdir
type : tensorboard
project : ifield
{% autoescape false %}
{% do require_defined("experiment_name", experiment_name, "single-shape" if do_condition else "multi-shape", strict=true) %}
{% set input_mode_abbr = hp.input_mode
.replace("plucker", "plkr")
.replace("perp_foot", "prpft")
%}
{% set output_mode_abbr = hp.output_mode
.replace("medial_sphere", "marf")
.replace("orthogonal_plane", "prif")
%}
experiment_name: experiment-{{ "" if experiment_name is not defined else experiment_name }}
{#--#}-{{ hp.dataset_item }}
{#--#}-{{ input_mode_abbr }}2{{ output_mode_abbr }}
{#--#}
{%- if hp.output_mode == "medial_sphere" -%}
{#--#}-{{ hp.n_atoms }}atom
{#--# }-{{ "rel" if hp.relative_out else "norel" }}
{#--# }-{{ "e" if hp.improve_miss_grads else "0" }}sqrt
{#--#}-{{ int(hp.loss_inscription) if hp.loss_inscription else "no" }}xinscr
{#--#}-{{ int(hp.miss_distance_coeff * 10) }}dmiss
{#--#}-{{ "geom" if hp.geom_init else "nogeom" }}
{#--#}{% if "curvature" in hp.grad_reg %}
{#- -#}-{{ int(hp.curvature_reg_coeff*10) }}crv
{#--#}{%- endif -%}
{%- elif hp.output_mode == "orthogonal_plane" -%}
{#--#}
{%- endif -%}
{#--#}-{{ int(hp.intersection_coeff*10) }}chit
{#--#}-{{ int(hp.normal_coeff*100) or "no" }}cnrml
{#--# }-{{ "do" if hp.normalize_ray_dirs else "no" }}raynorm
{#--#}-{{ hp.hidden_layers }}x{{ hp.hidden_features }}fc
{#--#}-{{ hp.nonlinearity or "linear" }}
{#--#}
{%- if hp.nonlinearity == "sine" -%}
{#--#}-{{ hp.omega }}omega
{#--#}
{%- endif -%}
{%- if hp.output_mode == "medial_sphere" -%}
{#--#}-{{ str(hp.atom_centroid_norm_std_reg_negexp).replace(*"-n") if hp.atom_centroid_norm_std_reg_negexp is not none else 'no' }}minatomstdngxp
{#--#}-{{ hp.sphere_grow_reg_coeff }}sphgrow
{#--#}
{%- endif -%}
{#--#}-{{ int(hp.dropout_percent*10) }}mdrop
{#--#}-{{ hp.normalization or "nonorm" }}
{#--#}-{{ hp.grad_reg }}
{#--#}{% if "multi_view" in hp.grad_reg %}
{#- -#}-{{ int(hp.multi_view_reg_coeff*10) }}dmv
{#--#}{%- endif -%}
{#--#}-{{ "concat" if do_condition else "nocond" }}
{#--#}-{{ int(hp.warmup_epochs*100) }}cwu{{ int(hp.lr_coeff*100) }}clr{{ int(hp.test_val_split_frac*100) }}tvs
{#--#}-{{ gen_run_uid(4) }} # select with --Oselect={{ index }}
{#--#}
{##}
{% endautoescape %}
IntersectionFieldAutoDecoderModel:
_extra: # used for easier introspection with jq
dataset_item: {{ hp.dataset_item | to_json}}
dataset_test_val_frac: {{ hp.test_val_split_frac }}
select: {{ index }}
input_mode : {{ hp.input_mode }} # in {plucker, perp_foot, both}
output_mode : {{ hp.output_mode }} # in {medial_sphere, orthogonal_plane}
#latent_features : 256 # int
#latent_features : 128 # int
latent_features : 16 # int
hidden_features : {{ hp.hidden_features }} # int
hidden_layers : {{ hp.hidden_layers }} # int
improve_miss_grads : {{ bool(hp.improve_miss_grads) | to_json }}
normalize_ray_dirs : {{ bool(hp.normalize_ray_dirs) | to_json }}
loss_intersection : {{ hp.intersection_coeff }}
loss_intersection_l2 : 0
loss_intersection_proj : 0
loss_intersection_proj_l2 : 0
loss_normal_cossim : {{ hp.normal_coeff }} * EaseSin(85, 15)
loss_normal_euclid : 0
loss_normal_cossim_proj : 0
loss_normal_euclid_proj : 0
{% if "multi_view" in hp.grad_reg %}
loss_multi_view_reg : 0.1 * {{ hp.multi_view_reg_coeff }} * Linear(50)
{% else %}
loss_multi_view_reg : 0
{% endif %}
{% if hp.output_mode == "orthogonal_plane" %}
loss_hit_cross_entropy : 1
{% elif hp.output_mode == "medial_sphere" %}
loss_hit_nodistance_l1 : 0
loss_hit_nodistance_l2 : 100 * {{ hp.miss_distance_coeff }}
loss_miss_distance_l1 : 0
loss_miss_distance_l2 : 10 * {{ hp.miss_distance_coeff }}
loss_inscription_hits : {{ 0.4 * hp.loss_inscription }}
loss_inscription_miss : 0
loss_inscription_hits_l2 : 0
loss_inscription_miss_l2 : {{ 6 * hp.loss_inscription }}
loss_sphere_grow_reg : 1e-6 * {{ hp.sphere_grow_reg_coeff }} # constant
loss_atom_centroid_norm_std_reg: (0.09*(1-Linear(40)) + 0.01) * {{ 10**(-hp.atom_centroid_norm_std_reg_negexp) if hp.atom_centroid_norm_std_reg_negexp is not none else 0 }}
{% else %}{#endif hp.output_mode == "medial_sphere" #}
THIS IS INVALID YAML
{% endif %}
loss_embedding_norm : 0.01**2 * Linear(30, 0.1)
opt_learning_rate : {{ hp.lr_coeff }} * 10**(-4-0.5*EaseSin(170, 30)) # layernorm
opt_warmup : {{ hp.warmup_epochs }}
opt_weight_decay : 5e-6 # float
{% if hp.output_mode == "medial_sphere" %}
# MedialAtomNet:
n_atoms : {{ hp.n_atoms }} # int
{% if hp.geom_init %}
final_init_wrr: [0.05, 0.6, 0.1]
{% else %}
final_init_wrr: null
{% endif %}
{% endif %}
# FCBlock:
normalization : {{ hp.normalization or "null" }} # in {null, layernorm, layernorm_na, weightnorm}
nonlinearity : {{ hp.nonlinearity or "null" }} # in {null, relu, leaky_relu, silu, softplus, elu, selu, sine, sigmoid, tanh }
{% set middle = 1 + hp.hidden_layers // 2 + (hp.hidden_layers % 2) %}{##}
concat_skipped_layers : [{{ middle }}, -1]
{% if do_condition %}
concat_conditioned_layers : [0, {{ middle }}]
{% else %}
concat_conditioned_layers : []
{% endif %}
# FCLayer:
negative_slope : 0.01 # float
omega_0 : {{ hp.omega }} # float
residual_mode : null # in {null, identity}
{% endif %}{# -Oselect #}
{% endfor %}
{% set index = next(counter) %}
# number of possible -Oselect: {{ index }}, from 0 to {{ index-1 }}
# local: for select in {0..{{ index-1 }}}; do python ... -Omode={{ mode }} -Oselect=$select ... ; done
# local: for select in {0..{{ index-1 }}}; do python -O {{ argv[0] }} model marf.yaml.j2 -Omode={{ mode }} -Oselect=$select -Oexperiment_name='{{ experiment_name }}' fit --accelerator gpu ; done
# slurm: sbatch --array=0-{{ index-1 }} runcommand.slurm python ... -Omode={{ mode }} -Oselect=\$SLURM_ARRAY_TASK_ID ...
# slurm: sbatch --array=0-{{ index-1 }} runcommand.slurm python -O {{ argv[0] }} model marf.yaml.j2 -Omode={{ mode }} -Oselect=\$SLURM_ARRAY_TASK_ID -Oexperiment_name='{{ experiment_name }}' fit --accelerator gpu --devices -1 --strategy ddp

849
experiments/summary.py Executable file
View File

@@ -0,0 +1,849 @@
#!/usr/bin/env python
from concurrent.futures import ThreadPoolExecutor, Future, ProcessPoolExecutor
from functools import partial
from more_itertools import first, last, tail
from munch import Munch, DefaultMunch, munchify, unmunchify
from pathlib import Path
from statistics import mean, StatisticsError
from mpl_toolkits.axes_grid1 import make_axes_locatable
from typing import Iterable, Optional, Literal
from math import isnan
import json
import stat
import matplotlib
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import os, os.path
import re
import shlex
import time
import itertools
import shutil
import subprocess
import sys
import traceback
import typer
import warnings
import yaml
import tempfile
# Directory layout, resolved relative to this script so the CLI works from any cwd.
# BUG FIX: was `Path(__file__).resolve()` (the script *file* itself), which made
# LOGDIR a path nested under summary.py; `.parent` yields the experiments/ dir.
EXPERIMENTS = Path(__file__).resolve().parent
LOGDIR = EXPERIMENTS / "logdir"
TENSORBOARD = LOGDIR / "tensorboard"
SLURM_LOGS = LOGDIR / "slurm_logs"
CACHED_SUMMARIES = LOGDIR / "cached_summaries"
COMPUTED_SCORES = LOGDIR / "computed_scores"
MISSING = object()  # sentinel distinct from None, marks "column absent from a row"
class SafeLoaderIgnoreUnknown(yaml.SafeLoader):
    """yaml.SafeLoader variant that loads any unknown tag as None instead of raising.

    Used to parse Lightning hparams.yaml dumps, which may contain tags
    (presumably python-object tags — see `_pickled_cli_args` usage below)
    that the plain SafeLoader refuses.
    """
    def ignore_unknown(self, node):
        # Fallback constructor: discard the tagged node's content entirely.
        return None
# Registering with tag=None installs this as the constructor for all
# otherwise-unrecognized tags (PyYAML's documented fallback mechanism).
SafeLoaderIgnoreUnknown.add_constructor(None, SafeLoaderIgnoreUnknown.ignore_unknown)
def camel_to_snake_case(text: str, sep: str = "_", join_abbreviations: bool = False) -> str:
    """Convert CamelCase *text* to lowercase *sep*-separated form.

    With ``join_abbreviations=True``, runs of consecutive single-letter pieces
    are merged ("parseURL" -> "parse_url" instead of "parse_u_r_l"). The merge
    is lossy: the original casing boundaries cannot be reconstructed.
    """
    pieces = [chunk.lower() for chunk in re.split(r'(?=[A-Z])', text) if chunk]
    if join_abbreviations:
        merged = []
        run = ""  # accumulates a run of consecutive single-letter pieces
        for piece in pieces:
            if len(piece) == 1:
                run += piece
                continue
            if run:
                merged.append(run)
                run = ""
            merged.append(piece)
        if run:
            merged.append(run)
        pieces = merged
    return sep.join(pieces)
def flatten_dict(data: dict, key_mapper: callable = lambda x: x) -> dict:
    """Flatten one level of nested dicts into "parent/child" keys.

    Returns *data* itself (unchanged) when no value is a dict. Nested keys are
    prefixed with the parent key passed through *key_mapper*; on key collision
    the flattened entries win. Not recursive: deeper nesting is left as-is.
    """
    if not any(isinstance(value, dict) for value in data.values()):
        return data
    # plain entries first, then the flattened ones (which overwrite on collision)
    flat = {key: value for key, value in data.items() if not isinstance(value, dict)}
    for parent, nested in data.items():
        if not isinstance(nested, dict):
            continue
        prefix = key_mapper(parent)
        for child, value in nested.items():
            flat[f"{prefix}/{child}"] = value
    return flat
def parse_jsonl(data: str) -> Iterable[dict]:
    """Lazily parse JSON-Lines *data*, yielding one object per non-blank line."""
    for line in data.splitlines():
        if line.strip():
            yield json.loads(line)
def read_jsonl(path: Path) -> Iterable[dict]:
    """Read *path* and lazily yield one parsed object per non-blank JSONL line."""
    yield from parse_jsonl(path.read_text())
def get_experiment_paths(filter: str | None, assert_dumped = False) -> Iterable[Path]:
    """Yield experiment run directories below TENSORBOARD.

    A directory qualifies when its name matches the *filter* regex (None
    matches everything), it contains an hparams.yaml and at least one
    tfevents file. Incomplete runs are skipped with a warning. With
    *assert_dumped* (and __debug__), assert the scalars were already dumped.
    """
    for run_dir in TENSORBOARD.iterdir():
        if filter is not None and re.search(filter, run_dir.name) is None:
            continue
        if not run_dir.is_dir():
            continue
        if not (run_dir / "hparams.yaml").is_file():
            warnings.warn(f"Missing hparams: {run_dir}")
            continue
        if not any(run_dir.glob("events.out.tfevents.*")):
            warnings.warn(f"Missing tfevents: {run_dir}")
            continue
        if __debug__ and assert_dumped:
            scalars = run_dir / "scalars"
            assert (scalars / "epoch.json").is_file(), run_dir
            assert (scalars / "IntersectionFieldAutoDecoderModel.validation_step/loss.json").is_file(), run_dir
            assert (scalars / "IntersectionFieldAutoDecoderModel.training_step/loss.json").is_file(), run_dir
        yield run_dir
def dump_pl_tensorboard_hparams(experiment: Path):
    """Extract human-readable artifacts from a run's Lightning hparams.yaml.

    Writes three files into the experiment directory:
    - config.yaml:  the original cli config text (shebang and argv comment preserved)
    - environ.yaml: the environment variables recorded on the training host
    - repo.patch:   the vcs diff recorded for the code that was run
    """
    with (experiment / "hparams.yaml").open() as f:
        hparams = yaml.load(f, Loader=SafeLoaderIgnoreUnknown)
    shebang = None
    with (experiment / "config.yaml").open("w") as f:
        # _pickled_cli_args._raw_yaml holds the config file text as passed on the cli
        raw_yaml = hparams.get('_pickled_cli_args', {}).get('_raw_yaml', "").replace("\n\r", "\n")
        if raw_yaml.startswith("#!"): # preserve shebang
            shebang, _, raw_yaml = raw_yaml.partition("\n")
            f.write(f"{shebang}\n")
        # record the original command line as a comment, for reproducibility
        f.write(f"# {' '.join(map(shlex.quote, hparams.get('_pickled_cli_args', {}).get('sys_argv', ['None'])))}\n\n")
        f.write(raw_yaml)
    if shebang is not None:
        # the config carried a shebang, so make the written copy user-executable too
        os.chmod(experiment / "config.yaml", (experiment / "config.yaml").stat().st_mode | stat.S_IXUSR)
    print(experiment / "config.yaml", "written!", file=sys.stderr)
    with (experiment / "environ.yaml").open("w") as f:
        yaml.safe_dump(hparams.get('_pickled_cli_args', {}).get('host', {}).get('environ'), f)
    print(experiment / "environ.yaml", "written!", file=sys.stderr)
    with (experiment / "repo.patch").open("w") as f:
        f.write(hparams.get('_pickled_cli_args', {}).get('host', {}).get('vcs', "None"))
    print(experiment / "repo.patch", "written!", file=sys.stderr)
def dump_simple_tf_events_to_jsonl(output_dir: Path, *tf_files: Path):
    """Dump every simple (scalar) summary in *tf_files* to one JSONL file per tag.

    Each output line is {"step", "value", "wall_time"}; files are laid out as
    <output_dir>/<tag>.json and opened lazily on the first event of that tag.
    Existing files are truncated (mode "w").
    """
    from google.protobuf.json_format import MessageToDict
    import tensorboard.backend.event_processing.event_accumulator
    s, l = {}, [] # reused sentinels
    #resource.setrlimit(resource.RLIMIT_NOFILE, (2**16,-1))
    file_handles = {}
    try:
        for tffile in tf_files:
            # NOTE(review): event_file_loader is reached as an attribute of the
            # event_processing package; only event_accumulator is imported above,
            # so this relies on it importing event_file_loader transitively —
            # fragile, verify against the installed tensorboard version.
            loader = tensorboard.backend.event_processing.event_file_loader.LegacyEventFileLoader(str(tffile))
            for event in loader.Load():
                # `s`/`l` are shared empty defaults so missing keys simply yield nothing
                for summary in MessageToDict(event).get("summary", s).get("value", l):
                    if "simpleValue" in summary:
                        tag = summary["tag"]
                        if tag not in file_handles:
                            fname = output_dir / f"{tag}.json"
                            print(f"Opening {str(fname)!r}...", file=sys.stderr)
                            fname.parent.mkdir(parents=True, exist_ok=True)
                            file_handles[tag] = fname.open("w") # ("a")
                        # MessageToDict may render floats as strings; normalize
                        val = summary["simpleValue"]
                        data = json.dumps({
                            "step" : event.step,
                            "value" : float(val) if isinstance(val, str) else val,
                            "wall_time" : event.wall_time,
                        })
                        file_handles[tag].write(f"{data}\n")
    finally:
        if file_handles:
            print("Closing json files...", file=sys.stderr)
            for k, v in file_handles.items():
                v.close()
# Columns that always survive filter_jsonl_columns(), even when they would
# otherwise be dropped as uninteresting (constant or all-falsy).
NO_FILTER = {
    "__uid",
    "_minutes",
    "_epochs",
    "_hp_nonlinearity",
    "_val_uloss_intersection",
    "_val_uloss_normal_cossim",
    # NOTE(review): "_val_uloss_intersection" was listed twice; the duplicate was
    # removed (a set ignores it anyway) — possibly a different key was intended.
}
def filter_jsonl_columns(data: Iterable[dict | None], no_filter=NO_FILTER) -> list[dict]:
    """Drop uninteresting columns from flattened summary rows and harmonize types.

    A column survives when it has at least one truthy value not in
    ("None", "0", "0.0") AND more than one distinct value across the rows
    (a column missing from some rows counts as one extra distinct value).
    Columns named in *no_filter* always survive. Rows that are None are
    discarded, the sine nonlinearity is merged with its omega, and every
    surviving column is coerced to the most general type observed for it.
    NOTE(review): the annotation says list[dict] but a generator is returned.
    """
    def merge_siren_omega(data: dict) -> dict:
        # Fold hp_omega_0 into the nonlinearity column ("sine" -> "sine-30"),
        # then drop the hp_omega_0 column itself.
        return {
            key: (
                f"{val}-{data.get('hp_omega_0', 'ERROR')}"
                if (key.removeprefix("_"), val) == ("hp_nonlinearity", "sine") else
                val
            )
            for key, val in data.items()
            if key != "hp_omega_0"
        }
    def remove_uninteresting_cols(rows: list[dict]) -> Iterable[dict]:
        unique_vals = {}
        def register_val(key, val):
            # side effect: record the value's repr; returns val unchanged so the
            # caller's `and` below short-circuits on falsy values — columns whose
            # values are all falsy never make it into the whitelist
            unique_vals.setdefault(key, set()).add(repr(val))
            return val
        whitelisted = {
            key
            for row in rows
            for key, val in row.items()
            if register_val(key, val) and val not in ("None", "0", "0.0")
        }
        # a column absent from some rows gets MISSING as an extra distinct value
        for key in unique_vals:
            for row in rows:
                if key not in row:
                    unique_vals[key].add(MISSING)
        # constant columns (a single distinct value everywhere) are uninteresting
        for key, vals in unique_vals.items():
            if key not in whitelisted: continue
            if len(vals) == 1:
                whitelisted.remove(key)
        whitelisted.update(no_filter)
        yield from (
            {
                key: val
                for key, val in row.items()
                if key in whitelisted
            }
            for row in rows
        )
    def pessemize_types(rows: list[dict]) -> Iterable[dict]:
        # Coerce each column to the most general type seen for it, where
        # "general" is the lowest index in `order` (str most general, then float, ...).
        types = {}
        order = (str, float, int, bool, tuple, type(None))
        for row in rows:
            for key, val in row.items():
                if isinstance(val, list): val = tuple(val)
                assert type(val) in order, (type(val), val)
                index = order.index(type(val))
                types[key] = min(types.get(key, 999), index)
        yield from (
            {
                key: order[types[key]](val) if val is not None else None
                for key, val in row.items()
            }
            for row in rows
        )
    # lazy pipeline; remove_uninteresting_cols/pessemize_types need full lists
    # since they make a statistics pass before yielding
    data = (row for row in data if row is not None)
    data = map(partial(flatten_dict, key_mapper=camel_to_snake_case), data)
    data = map(merge_siren_omega, data)
    data = remove_uninteresting_cols(list(data))
    data = pessemize_types(list(data))
    return data
# Accepted rendering modes for plot_losses()
PlotMode = Literal["stackplot", "lineplot"]
def plot_losses(experiments: list[Path], mode: PlotMode, write: bool = False, dump: bool = False, training: bool = False, unscaled: bool = False, force=True):
    """Plot the per-term loss curves of *experiments* from their dumped scalars.

    mode:     "stackplot" (each loss as a fraction of the step total, plus a
              small absolute-total axis below) or "lineplot" (raw curves).
    write:    render pngs under <experiment>/plots/ for all four
              training/validation x (un)scaled combinations, in parallel,
              instead of showing an interactive window.
    training, unscaled: which scalar set to plot when not writing.
    force:    re-render even when the target png already exists.
    NOTE(review): the `dump` parameter is accepted but never read here.
    """
    def get_losses(experiment: Path, training: bool = True, unscaled: bool = False) -> Iterable[Path]:
        # Select the dumped per-term loss scalar files for one combination.
        if not training and unscaled:
            return experiment.glob("scalars/*.validation_step/unscaled_loss_*.json")
        elif not training and not unscaled:
            return experiment.glob("scalars/*.validation_step/loss_*.json")
        elif training and unscaled:
            return experiment.glob("scalars/*.training_step/unscaled_loss_*.json")
        elif training and not unscaled:
            return experiment.glob("scalars/*.training_step/loss_*.json")
    print("Mapping colors...")
    # when writing, all four combinations are rendered; otherwise only the requested one
    configurations = [
        dict(unscaled=unscaled, training=training),
    ] if not write else [
        dict(unscaled=False, training=False),
        dict(unscaled=False, training=True),
        dict(unscaled=True, training=False),
        dict(unscaled=True, training=True),
    ]
    # give each "<model>.<loss_name>" a stable color across every figure
    legends = set(
        f"""{
            loss.parent.name.split(".", 1)[0]
        }.{
            loss.name.removesuffix(loss.suffix).removeprefix("unscaled_")
        }"""
        for experiment in experiments
        for kw in configurations
        for loss in get_losses(experiment, **kw)
    )
    colormap = dict(zip(
        sorted(legends),
        itertools.cycle(mcolors.TABLEAU_COLORS),
    ))
    def mkplot(experiment: Path, training: bool = True, unscaled: bool = False) -> tuple[bool, str]:
        # Render one figure. Returns (skipped_or_failed, reason).
        label = f"{'unscaled' if unscaled else 'scaled'} {'training' if training else 'validation'}"
        if write:
            # migrate pngs from the old flat layout into plots/
            old_savefig_fname = experiment / f"{label.replace(' ', '-')}-{mode}.png"
            savefig_fname = experiment / "plots" / f"{label.replace(' ', '-')}-{mode}.png"
            savefig_fname.parent.mkdir(exist_ok=True, parents=True)
            if old_savefig_fname.is_file():
                old_savefig_fname.rename(savefig_fname)
            if savefig_fname.is_file() and not force:
                return True, "savefig_fname already exists"
        # Get and sort data
        losses = {}
        for loss in get_losses(experiment, training=training, unscaled=unscaled):
            model = loss.parent.name.split(".", 1)[0]
            name = loss.name.removesuffix(loss.suffix).removeprefix("unscaled_")
            losses[f"{model}.{name}"] = (loss, list(read_jsonl(loss)))
        losses = dict(sorted(losses.items())) # sort keys
        if not losses:
            return True, "no losses"
        # unwrap
        # steps come from the first loss series; NaNs are plotted as 0
        steps = [i["step"] for i in first(losses.values())[1]]
        values = [
            [i["value"] if not isnan(i["value"]) else 0 for i in data]
            for name, (scalar, data) in losses.items()
        ]
        # normalize
        if mode == "stackplot":
            totals = list(map(sum, zip(*values)))
            values = [
                [i / t for i, t in zip(data, totals)]
                for data in values
            ]
        print(experiment.name, label)
        fig, ax = plt.subplots(figsize=(16, 12))
        if mode == "stackplot":
            ax.stackplot(steps, values,
                colors = list(map(colormap.__getitem__, losses.keys())),
                labels = list(
                    label.split(".", 1)[1].removeprefix("loss_")
                    for label in losses.keys()
                ),
            )
            ax.set_xlim(0, steps[-1])
            ax.set_ylim(0, 1)
            ax.invert_yaxis()
        elif mode == "lineplot":
            # NOTE(review): this for-loop rebinds `label`, so the title below
            # shows the *last loss key* instead of "scaled training" etc. —
            # looks unintended, confirm before relying on the titles.
            for data, color, label in zip(
                values,
                map(colormap.__getitem__, losses.keys()),
                list(losses.keys()),
            ):
                ax.plot(steps, data,
                    color = color,
                    label = label,
                )
            ax.set_xlim(0, steps[-1])
        else:
            raise ValueError(f"{mode=}")
        ax.legend()
        ax.set_title(f"{label} loss\n{experiment.name}")
        ax.set_xlabel("Step")
        ax.set_ylabel("loss%")
        if mode == "stackplot":
            # extra axis below the main one, showing the unnormalized total loss
            ax2 = make_axes_locatable(ax).append_axes("bottom", 0.8, pad=0.05, sharex=ax)
            ax2.stackplot( steps, totals )
            for tl in ax.get_xticklabels(): tl.set_visible(False)
        fig.tight_layout()
        if write:
            fig.savefig(savefig_fname, dpi=300)
            print(savefig_fname)
            plt.close(fig)
        return False, None
    print("Plotting...")
    if write:
        matplotlib.use('agg') # fixes "WARNING: QApplication was not created in the main() thread."
    any_error = False
    if write:
        # render all figures in parallel worker threads
        with ThreadPoolExecutor(max_workers=None) as pool:
            futures = [
                (experiment, pool.submit(mkplot, experiment, **kw))
                for experiment in experiments
                for kw in configurations
            ]
    else:
        # interactive: run serially, but wrap results in resolved Futures so
        # the reporting loop below can treat both paths uniformly
        def mkfuture(item):
            f = Future()
            f.set_result(item)
            return f
        futures = [
            (experiment, mkfuture(mkplot(experiment, **kw)))
            for experiment in experiments
            for kw in configurations
        ]
    for experiment, future in futures:
        try:
            err, msg = future.result()
        except Exception:
            traceback.print_exc(file=sys.stderr)
            any_error = True
            continue
        if err:
            print(f"{msg}: {experiment.name}")
            any_error = True
            continue
    if not any_error and not write: # show in main thread
        plt.show()
    elif not write:
        print("There were errors, will not show figure...", file=sys.stderr)
# =========
# The typer cli application; each @app.command below registers a subcommand.
app = typer.Typer(no_args_is_help=True, add_completion=False)
@app.command(help="Dump simple tensorboard events to json and extract some pytorch lightning hparams")
def tf_dump(tfevent_files: list[Path], j: int = typer.Option(1, "-j"), force: bool = False):
    """Expand the given paths to tfevents files and dump their scalars to JSONL.

    Accepts tfevents files, experiment directories, or hparams.yaml/config.yaml
    paths (each expanded to its sibling tfevents files). Experiments already
    dumped (scalars/epoch.json newer than the events file) are skipped unless
    --force. Work is spread over a process pool.
    NOTE(review): the -j option is accepted but never used; the pool sizes itself.
    """
    # expand to all tfevents files (there may be more than one)
    tfevent_files = sorted(set([
        tffile
        for tffile in tfevent_files
        if tffile.name.startswith("events.out.tfevents.")
    ] + [
        tffile
        for experiment_dir in tfevent_files
        if experiment_dir.is_dir()
        for tffile in experiment_dir.glob("events.out.tfevents.*")
    ] + [
        tffile
        for hparam_file in tfevent_files
        if hparam_file.name in ("hparams.yaml", "config.yaml")
        for tffile in hparam_file.parent.glob("events.out.tfevents.*")
    ]))
    # filter already dumped
    if not force:
        tfevent_files = [
            tffile
            for tffile in tfevent_files
            if not (
                (tffile.parent / "scalars/epoch.json").is_file()
                and
                tffile.stat().st_mtime < (tffile.parent / "scalars/epoch.json").stat().st_mtime
            )
        ]
    if not tfevent_files:
        raise typer.BadParameter("Nothing to be done, consider --force")
    # group events files by their output directory (<experiment>/scalars)
    jobs = {}
    for tffile in tfevent_files:
        if not tffile.is_file():
            print("ERROR: file not found:", tffile, file=sys.stderr)
            continue
        output_dir = tffile.parent / "scalars"
        jobs.setdefault(output_dir, []).append(tffile)
    with ProcessPoolExecutor() as p:
        # also extract config/environ/patch artifacts for each experiment
        for experiment in set(tffile.parent for tffile in tfevent_files):
            p.submit(dump_pl_tensorboard_hparams, experiment)
        for output_dir, tffiles in jobs.items():
            p.submit(dump_simple_tf_events_to_jsonl, output_dir, *tffiles)
@app.command(help="Propose experiment regexes")
def propose(cmd: str = typer.Argument("summary"), null: bool = False):
    """Print one ready-to-run shell command per distinct experiment name/date pattern."""
    patterns = set()
    for run_dir in TENSORBOARD.iterdir():
        if not run_dir.is_dir():
            continue
        if not (run_dir / "hparams.yaml").is_file():
            continue
        # run dirs are named <prefix>-<name>-...-<year>-<month>-<day>-<hhmm>-<uid>
        _prefix, name, *_hparams, year, month, day, _hhmm, _uid = run_dir.name.split("-")
        patterns.add(f"{name}.*-{year}-{month}-{day}")
    # order by the date part of each proposed regex
    proposals = sorted(patterns, key=lambda pattern: pattern.split(".*-", 1)[1])
    redirect = '>/dev/null ' if null else ''
    print("\n".join(
        f"{redirect}{sys.argv[0]} {cmd or 'summary'} {shlex.quote(pattern)}"
        for pattern in proposals
    ))
@app.command("list", help="List used experiment regexes")
def list_cached_summaries(cmd: str = typer.Argument("summary")):
    """Print a shell command for every cached summary regex, sorted by its date part."""
    cached = []
    if CACHED_SUMMARIES.is_dir():
        for entry in CACHED_SUMMARIES.iterdir():
            if entry.suffix != ".jsonl":
                continue
            if not (entry.is_file() and entry.stat().st_size):
                continue
            cached.append(entry.name.removesuffix(".jsonl"))
    def order(key: str) -> list[str]:
        # sort by the digits of the trailing date pattern, ties broken by the full regex
        digits = re.sub(r'[^0-9\-]', '', key.split(".*")[-1]).strip("-")
        return digits.split("-") + [key]
    print("\n".join(
        f"{sys.argv[0]} {cmd or 'summary'} {shlex.quote(regex)}"
        for regex in sorted(cached, key=order)
    ))
@app.command(help="Precompute the summary of an experiment regex")
def compute_summary(filter: str, force: bool = False, dump: bool = False, no_cache: bool = False):
    """Compute (and cache) a one-row-per-experiment JSONL summary for *filter*.

    Per-experiment results are also cached as <run>/train_summary.json and
    reused while newer than the dumped scalars. Raises FileExistsError when the
    regex-level cache exists and --force was not given. Returns the JSONL text.
    """
    cache = CACHED_SUMMARIES / f"{filter}.jsonl"
    if cache.is_file() and cache.stat().st_size:
        if not force:
            raise FileExistsError(cache)
    def mk_summary(path: Path) -> dict | None:
        """Summarize one experiment directory; None when its scalars are unreadable."""
        # renamed from `cache` to stop shadowing the regex-level cache above
        summary_cache = path / "train_summary.json"
        if summary_cache.is_file() and summary_cache.stat().st_size and summary_cache.stat().st_mtime > (path/"scalars/epoch.json").stat().st_mtime:
            with summary_cache.open() as f:
                return json.load(f)
        with (path / "hparams.yaml").open() as f:
            hparams = munchify(yaml.load(f, Loader=SafeLoaderIgnoreUnknown), factory=partial(DefaultMunch, None))
        # the original cli config yaml text is embedded inside the hparams dump
        config = hparams._pickled_cli_args._raw_yaml
        config = munchify(yaml.load(config, Loader=SafeLoaderIgnoreUnknown), factory=partial(DefaultMunch, None))
        try:
            train_loss = list(read_jsonl(path / "scalars/IntersectionFieldAutoDecoderModel.training_step/loss.json"))
            val_loss = list(read_jsonl(path / "scalars/IntersectionFieldAutoDecoderModel.validation_step/loss.json"))
        except Exception:  # BUG FIX: was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit
            traceback.print_exc(file=sys.stderr)
            return None
        out = Munch()
        out.uid = path.name.rsplit("-", 1)[-1]
        out.name = path.name
        out.date = "-".join(path.name.split("-")[-5:-1])
        out.epochs = int(last(read_jsonl(path / "scalars/epoch.json"))["value"])
        out.steps = val_loss[-1]["step"]
        out.gpu = hparams._pickled_cli_args.host.gpus[1][1]
        # guard against a zero wall-time span (e.g. a single validation step)
        if val_loss[-1]["wall_time"] - val_loss[0]["wall_time"] > 0:
            out.batches_per_second = val_loss[-1]["step"] / (val_loss[-1]["wall_time"] - val_loss[0]["wall_time"])
        else:
            out.batches_per_second = 0
        out.minutes = (val_loss[-1]["wall_time"] - train_loss[0]["wall_time"]) / 60
        if (path / "scalars/PsutilMonitor/gpu.00.memory.used.json").is_file():
            # BUG FIX: the computed peak was previously discarded; record it
            out.gpu_memory_used = max(i["value"] for i in read_jsonl(path / "scalars/PsutilMonitor/gpu.00.memory.used.json"))
        # mean of the last 5 logged values of every validation metric
        for metric_path in (path / "scalars/IntersectionFieldAutoDecoderModel.validation_step").glob("*.json"):
            if not metric_path.is_file() or not metric_path.stat().st_size: continue
            metric_name = metric_path.name.removesuffix(".json")
            metric_data = read_jsonl(metric_path)
            try:
                out[f"val_{metric_name}"] = mean(i["value"] for i in tail(5, metric_data))
            except StatisticsError:
                out[f"val_{metric_name}"] = float('nan')
        # same, for a hand-picked subset of training metrics
        for metric_path in (path / "scalars/IntersectionFieldAutoDecoderModel.training_step").glob("*.json"):
            if not any(i in metric_path.name for i in ("miss_radius_grad", "sphere_center_grad", "loss_tangential_reg", "multi_view")): continue
            if not metric_path.is_file() or not metric_path.stat().st_size: continue
            metric_name = metric_path.name.removesuffix(".json")
            metric_data = read_jsonl(metric_path)
            try:
                out[f"train_{metric_name}"] = mean(i["value"] for i in tail(5, metric_data))
            except StatisticsError:
                out[f"train_{metric_name}"] = float('nan')
        out.hostname = hparams._pickled_cli_args.host.hostname
        # flatten scalar hyperparameters (one level deep) into hp_* columns
        for key, val in config.IntersectionFieldAutoDecoderModel.items():
            if isinstance(val, dict):
                out.update({f"hp_{key}_{k}": v for k, v in val.items()})
            elif isinstance(val, float | int | str | bool | None):
                out[f"hp_{key}"] = val
        with summary_cache.open("w") as f:
            json.dump(unmunchify(out), f)
        return dict(out)
    experiments = list(get_experiment_paths(filter, assert_dumped=not dump))
    if not experiments:
        raise typer.BadParameter("No matching experiment")
    if dump:
        try:
            tf_dump(experiments) # force=force_dump)
        except typer.BadParameter:
            pass
    # does literally nothing, thanks GIL
    with ThreadPoolExecutor() as p:
        results = list(p.map(mk_summary, experiments))
    if any(result is None for result in results):
        if all(result is None for result in results):
            print("No summary succeeded", file=sys.stderr)
            # BUG FIX: typer.Exit takes `code`, not `exit_code` (was a TypeError)
            raise typer.Exit(1)
        warnings.warn("Some summaries failed:\n" + "\n".join(
            str(experiment)
            for result, experiment in zip(results, experiments)
            if result is None
        ))
    summaries = "\n".join( map(json.dumps, results) )
    if not no_cache:
        cache.parent.mkdir(parents=True, exist_ok=True)
        with cache.open("w") as f:
            f.write(summaries)
    return summaries
@app.command(help="Show the summary of a experiment regex, precompute it if needed")
def summary(filter: Optional[str] = typer.Argument(None), force: bool = False, dump: bool = False, all: bool = False):
    """Show the cached summary for *filter*, computing it first when missing or forced.

    With no filter, lists the available cached regexes instead. In an interactive
    terminal with visidata installed, opens the rows in `vd`; otherwise dumps
    the raw cached JSONL to stdout.
    """
    if filter is None:
        return list_cached_summaries("summary")
    renames = (
        (r'^val_unscaled_loss_', r'val_uloss_'),
        (r'^train_unscaled_loss_', r'train_uloss_'),
        (r'^val_loss_', r'val_sloss_'),
        (r'^train_loss_', r'train_sloss_'),
    )
    def key_mangler(key: str) -> str:
        # shorten loss column names so they fit the table view
        for pattern, replacement in renames:
            key = re.sub(pattern, replacement, key)
        return key
    cache = CACHED_SUMMARIES / f"{filter}.jsonl"
    if force or not (cache.is_file() and cache.stat().st_size):
        compute_summary(filter, force=force, dump=dump)
    assert cache.is_file() and cache.stat().st_size, (cache, cache.stat())
    interactive = os.isatty(0) and os.isatty(1) and shutil.which("vd")
    if not interactive:
        # no terminal or no visidata: dump the raw cached jsonl
        with cache.open() as f:
            print(f.read())
        return
    rows = read_jsonl(cache)
    rows = (
        {key_mangler(k): v for k, v in row.items()} if row is not None else None
        for row in rows
    )
    if not all:
        rows = filter_jsonl_columns(rows)
        # scaled losses are redundant next to the unscaled ones; hide them
        rows = (
            {k: v for k, v in row.items() if not k.startswith(("val_sloss_", "train_sloss_"))}
            for row in rows
        )
    subprocess.run(
        ["vd", "-f", "jsonl"],
        input="\n".join(map(json.dumps, rows)),
        text=True,
        check=True,
    )
@app.command(help="Filter uninteresting keys from jsonl stdin")
def filter_cols():
    """Read JSONL rows from stdin, drop uninteresting columns, write JSONL to stdout."""
    raw_rows = (json.loads(line) for line in sys.stdin.readlines() if line.strip())
    kept_rows = filter_jsonl_columns(raw_rows)
    print("\n".join(json.dumps(row) for row in kept_rows))
@app.command(help="Run a command for each experiment matched by experiment regex")
def exec(filter: str, cmd: list[str], j: int = typer.Option(1, "-j"), dumped: bool = False, undumped: bool = False):
    """Run *cmd* once per matching experiment, with fd/GNU-parallel placeholders.

    "{}" expands to the experiment's hparams.yaml, "{//}" to the experiment
    directory; without any placeholder, hparams.yaml is appended. --dumped /
    --undumped restrict to experiments with / without dumped scalars. Exits
    with status 1 when any command fails.
    """
    # inspired by fd / gnu parallel
    def populate_cmd(experiment: Path, cmd: Iterable[str]) -> Iterable[str]:
        substituted = False  # renamed from `any`: don't shadow the builtin
        for arg in cmd:
            if arg == "{}":
                substituted = True
                yield str(experiment / "hparams.yaml")
            elif arg == "{//}":
                substituted = True
                yield str(experiment)
            else:
                yield arg
        if not substituted:
            yield str(experiment / "hparams.yaml")
    with ThreadPoolExecutor(max_workers=j or None) as p:
        results = p.map(subprocess.run, (
            list(populate_cmd(experiment, cmd))
            for experiment in get_experiment_paths(filter)
            if not dumped or (experiment / "scalars/epoch.json").is_file()
            if not undumped or not (experiment / "scalars/epoch.json").is_file()
        ))
    if any(i.returncode for i in results):
        # BUG FIX: was `return typer.Exit(1)` — typer ignores return values, so
        # failures exited with status 0; typer.Exit must be *raised*.
        raise typer.Exit(1)
@app.command(help="Show stackplot of experiment loss")
def stackplot(filter: str, write: bool = False, dump: bool = False, training: bool = False, unscaled: bool = False, force: bool = False):
    """Render per-loss stackplots for every experiment matching *filter*."""
    matches = list(get_experiment_paths(filter, assert_dumped=not dump))
    if not matches:
        raise typer.BadParameter("No match")
    if dump:
        try:
            tf_dump(matches)
        except typer.BadParameter:
            pass  # everything already dumped is fine here
    plot_losses(
        matches,
        mode="stackplot",
        write=write,
        dump=dump,
        training=training,
        unscaled=unscaled,
        force=force,
    )
@app.command(help="Show lineplot of experiment loss")
def lineplot(filter: str, write: bool = False, dump: bool = False, training: bool = False, unscaled: bool = False, force: bool = False):
    """Render per-loss line plots for every experiment matching *filter*.

    BUG FIX: the cli help text previously said "stackplot" — a copy-paste
    leftover from the stackplot command above.
    """
    experiments = list(get_experiment_paths(filter, assert_dumped=not dump))
    if not experiments:
        raise typer.BadParameter("No match")
    if dump:
        try:
            tf_dump(experiments)
        except typer.BadParameter:
            pass  # everything already dumped is fine here
    plot_losses(experiments,
        mode = "lineplot",
        write = write,
        dump = dump,
        training = training,
        unscaled = unscaled,
        force = force,
    )
@app.command(help="Open tensorboard for the experiments matching the regex")
def tensorboard(filter: Optional[str] = typer.Argument(None), watch: bool = False):
    """Launch tensorboard on a temporary symlink farm of the matching experiments.

    tensorboard has no regex filtering of its own, so only the matching run
    directories are exposed via symlinks in a temp dir. With --watch, newly
    appearing experiments are linked in every 10s while tensorboard runs.
    Without a filter, lists cached regexes instead.
    """
    if filter is None:
        return list_cached_summaries("tensorboard")
    experiments = list(get_experiment_paths(filter, assert_dumped=False))
    if not experiments:
        raise typer.BadParameter("No match")
    with tempfile.TemporaryDirectory(suffix=f"ifield-{filter}") as d:
        treefarm = Path(d)
        with ThreadPoolExecutor(max_workers=2) as p:
            for experiment in experiments:
                (treefarm / experiment.name).symlink_to(experiment)
            cmd = ["tensorboard", "--logdir", d]
            print("+", *map(shlex.quote, cmd), file=sys.stderr)
            # note: rebinds the function name `tensorboard` locally to the Future
            tensorboard = p.submit(subprocess.run, cmd, check=True)
            if not watch:
                tensorboard.result()
            else:
                # poll for new experiment dirs for as long as tensorboard is alive
                all_experiments = set(get_experiment_paths(None, assert_dumped=False))
                while not tensorboard.done():
                    time.sleep(10)
                    new_experiments = set(get_experiment_paths(None, assert_dumped=False)) - all_experiments
                    if new_experiments:
                        for experiment in new_experiments:
                            print(f"Adding {experiment.name!r}...", file=sys.stderr)
                            (treefarm / experiment.name).symlink_to(experiment)
                        all_experiments.update(new_experiments)
@app.command(help="Compute evaluation metrics")
def metrics(filter: Optional[str] = typer.Argument(None), dump: bool = False, dry: bool = False, prefix: Optional[str] = typer.Option(None), derive: bool = False, each: bool = False, no_total: bool = False):
    """Ensure compute-scores metrics exist for matching experiments, then print them.

    Missing metrics json files are produced by shelling out to ./marf.py
    (optionally behind *prefix*, e.g. a job scheduler command). With --derive,
    raw counters are turned into iou/precision/recall/f-score/cd/emd/cos
    metrics; --each keeps per-object rows next to the "_all_" aggregate.
    Without a filter, lists cached regexes instead.
    """
    if filter is None:
        return list_cached_summaries("metrics --derive")
    experiments = list(get_experiment_paths(filter, assert_dumped=False))
    if not experiments:
        raise typer.BadParameter("No match")
    if dump:
        try:
            tf_dump(experiments)
        except typer.BadParameter:
            pass
    def run(*cmd):
        # run (or with --dry just print) one command, optionally behind *prefix*
        if prefix is not None:
            cmd = [*shlex.split(prefix), *cmd]
        if dry:
            print(*map(shlex.quote, map(str, cmd)))
        else:
            print("+", *map(shlex.quote, map(str, cmd)))
            subprocess.run(cmd)
    # compute any missing metrics files (best + last checkpoint; for prif runs
    # additionally the outlier-filtered variants)
    for experiment in experiments:
        if no_total: continue
        if not (experiment / "compute-scores/metrics.json").is_file():
            run(
                "python", "./marf.py", "module", "--best", experiment / "hparams.yaml",
                "compute-scores", experiment / "compute-scores/metrics.json",
                "--transpose",
            )
        if not (experiment / "compute-scores/metrics-last.json").is_file():
            run(
                "python", "./marf.py", "module", "--last", experiment / "hparams.yaml",
                "compute-scores", experiment / "compute-scores/metrics-last.json",
                "--transpose",
            )
        if "2prif-" not in experiment.name: continue
        if not (experiment / "compute-scores/metrics-sans_outliers.json").is_file():
            run(
                "python", "./marf.py", "module", "--best", experiment / "hparams.yaml",
                "compute-scores", experiment / "compute-scores/metrics-sans_outliers.json",
                "--transpose", "--filter-outliers"
            )
        if not (experiment / "compute-scores/metrics-last-sans_outliers.json").is_file():
            run(
                "python", "./marf.py", "module", "--last", experiment / "hparams.yaml",
                "compute-scores", experiment / "compute-scores/metrics-last-sans_outliers.json",
                "--transpose", "--filter-outliers"
            )
    if dry: return
    if prefix is not None:
        print("prefix was used, assuming a job scheduler was used, will not print scores.", file=sys.stderr)
        return
    metrics = [
        *(experiment / "compute-scores/metrics.json" for experiment in experiments),
        *(experiment / "compute-scores/metrics-last.json" for experiment in experiments),
        *(experiment / "compute-scores/metrics-sans_outliers.json" for experiment in experiments if "2prif-" in experiment.name),
        *(experiment / "compute-scores/metrics-last-sans_outliers.json" for experiment in experiments if "2prif-" in experiment.name),
    ]
    if not no_total:
        assert all(metric.exists() for metric in metrics)
    else:
        metrics = (metric for metric in metrics if metric.exists())
    out = []
    for metric in metrics:
        experiment = metric.parent.parent.name
        is_last = metric.name in ("metrics-last.json", "metrics-last-sans_outliers.json")
        with metric.open() as f:
            data = json.load(f)
        if derive:
            derived = {}
            # top-level keys are object names, plus a "_hparams" entry with n_cd/n_emd
            objs = [i for i in data.keys() if i != "_hparams"]
            # None stands for the aggregate over all objects; with --each the
            # individual objects are derived first
            for obj in (objs if each else []) + [None]:
                if obj is None:
                    # sum the raw counters over all objects into a 0-defaulting munch
                    # (note: deliberately reuses the loop variable `obj`)
                    d = DefaultMunch(0)
                    for obj in objs:
                        for k, v in data[obj].items():
                            d[k] += v
                    obj = "_all_"
                    n_cd = data["_hparams"]["n_cd"] * len(objs)
                    n_emd = data["_hparams"]["n_emd"] * len(objs)
                else:
                    d = munchify(data[obj])
                    n_cd = data["_hparams"]["n_cd"]
                    n_emd = data["_hparams"]["n_emd"]
                precision = d.TP / (d.TP + d.FP)
                recall = d.TP / (d.TP + d.FN)
                derived[obj] = dict(
                    filtered = d.n_outliers / d.n if "n_outliers" in d else None,
                    iou = d.TP / (d.TP + d.FN + d.FP),
                    precision = precision,
                    recall = recall,
                    f_score = 2 * (precision * recall) / (precision + recall),
                    cd = d.cd_dist / n_cd,
                    emd = d.emd / n_emd,
                    cos_med = 1 - (d.cd_cos_med / n_cd) if "cd_cos_med" in d else None,
                    cos_jac = 1 - (d.cd_cos_jac / n_cd),
                )
            data = derived if each else derived["_all_"]
        data["uid"] = experiment.rsplit("-", 1)[-1]
        data["experiment_name"] = experiment
        data["is_last"] = is_last
        out.append(json.dumps(data))
    # interactive + visidata available: open a table view, otherwise print jsonl
    if derive and not each and os.isatty(0) and os.isatty(1) and shutil.which("vd"):
        subprocess.run(["vd", "-f", "jsonl"], input="\n".join(out), text=True, check=True)
    else:
        print("\n".join(out))
# Script entry point: dispatch to the typer cli defined above.
if __name__ == "__main__":
    app()