Add code
This commit is contained in:
624
experiments/marf.py
Executable file
624
experiments/marf.py
Executable file
@@ -0,0 +1,624 @@
|
||||
#!/usr/bin/env python3
|
||||
from abc import ABC, abstractmethod
|
||||
from argparse import Namespace
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from ifield import logging
|
||||
from ifield.cli import CliInterface
|
||||
from ifield.data.common.scan import SingleViewUVScan
|
||||
from ifield.data.coseg import read as coseg_read
|
||||
from ifield.data.stanford import read as stanford_read
|
||||
from ifield.datasets import stanford, coseg, common
|
||||
from ifield.models import intersection_fields
|
||||
from ifield.utils.operators import diff
|
||||
from ifield.viewer.ray_field import ModelViewer
|
||||
from munch import Munch
|
||||
from pathlib import Path
|
||||
from pytorch3d.loss.chamfer import chamfer_distance
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
from torch.nn import functional as F
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
from tqdm import tqdm
|
||||
from trimesh import Trimesh
|
||||
from typing import Union
|
||||
import builtins
|
||||
import itertools
|
||||
import json
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import rich
|
||||
import rich.pretty
|
||||
import statistics
|
||||
import torch
|
||||
# Global, deterministic seed for python/numpy/torch (also seeds dataloader workers).
pl.seed_everything(31337)
# Trade a little matmul precision for speed on Ampere+ GPUs (enables TF32).
torch.set_float32_matmul_precision('medium')


IField = intersection_fields.IntersectionFieldAutoDecoderModel # brevity
|
||||
|
||||
|
||||
class RayFieldAdDataModuleBase(pl.LightningDataModule, ABC):
    """Base LightningDataModule for auto-decoder ray-field training.

    Subclasses supply the concrete (uid, SingleViewUVScan) dataset via
    `mk_ad_dataset`, the list of observation ids used to key the
    auto-decoded latent codes, and ground-truth accessors used by the
    evaluation/visualization CLI actions.
    """

    @property
    @abstractmethod
    def observation_ids(self) -> list[str]:
        # Ids that key the per-object latent codes of the auto-decoder.
        ...

    @abstractmethod
    def mk_ad_dataset(self) -> common.AutodecoderDataset:
        # Construct the raw (uid, SingleViewUVScan) dataset for this source.
        ...

    @staticmethod
    @abstractmethod
    def get_trimesh_from_uid(uid) -> Trimesh:
        # Ground-truth mesh for a uid (used by the interactive viewer).
        ...

    @staticmethod
    @abstractmethod
    def get_sphere_scan_from_uid(uid) -> SingleViewUVScan:
        # Ground-truth sphere scan for a uid (used by score computation).
        ...

    def setup(self, stage=None):
        """Build the (optionally sub-sampled) dataset and split it train/val.

        With `step > 1` every scan is expanded into `step*step` strided
        sub-grids, and the split is arranged so all sub-grids of one scan
        land in the same partition.
        """
        assert stage in ["fit", None] # fit is for train/val, None is for all. "test" not supported ATM

        if not self.hparams.data_dir is None:
            coseg.config.DATA_PATH = self.hparams.data_dir
        step = self.hparams.step # brevity

        dataset = self.mk_ad_dataset()
        n_items_pre_step_mapping = len(dataset)

        if step > 1:
            # presumably each registered .map() below adds one extended view
            # per (sx, sy) offset — TODO confirm against TransformExtendedDataset
            dataset = common.TransformExtendedDataset(dataset)

        # Register one mapped view per (sx, sy) sub-grid offset; each view
        # subsamples the UV scan by `step` along both image axes.
        for sx in range(step):
            for sy in range(step):
                def make_slicer(sx, sy, step) -> callable: # the closure is required
                    # Bind sx/sy/step now: a bare lambda in the loop body would
                    # late-bind and see only the final loop values.
                    if step > 1:
                        return lambda t: t[sx::step, sy::step]
                    else:
                        return lambda t: t
                @dataset.map(slicer=make_slicer(sx, sy, step))
                def unpack(sample: tuple[str, SingleViewUVScan], slicer: callable):
                    # Convert a (uid, scan) pair into the dict of arrays the model consumes.
                    scan: SingleViewUVScan = sample[1]
                    assert not scan.hits.shape[0] % step, f"{scan.hits.shape[0]=} not divisible by {step=}"
                    assert not scan.hits.shape[1] % step, f"{scan.hits.shape[1]=} not divisible by {step=}"

                    return {
                        "z_uid"     : sample[0],
                        "origins"   : scan.cam_pos,
                        "dirs"      : slicer(scan.ray_dirs),
                        "points"    : slicer(scan.points),
                        "hits"      : slicer(scan.hits),
                        "miss"      : slicer(scan.miss),
                        "normals"   : slicer(scan.normals),
                        "distances" : slicer(scan.distances),
                    }

        # Split each object into train/val with SampleSplit
        n_items = len(dataset)
        n_val = int(n_items * self.hparams.val_fraction)
        n_train = n_items - n_val
        self.generator = torch.Generator().manual_seed(self.hparams.prng_seed)

        # split the dataset such that all steps are in same part
        assert n_items == n_items_pre_step_mapping * step * step, (n_items, n_items_pre_step_mapping, step)
        # Permute whole scans, then enumerate their step*step sub-grids
        # contiguously so a scan never straddles the train/val boundary.
        indices = [
            i*step*step + sx*step + sy
            for i in torch.randperm(n_items_pre_step_mapping, generator=self.generator).tolist()
            for sx in range(step)
            for sy in range(step)
        ]
        # The sort-by-random-key shuffles each partition reproducibly via self.generator.
        self.dataset_train = Subset(dataset, sorted(indices[:n_train], key=lambda x: torch.rand(1, generator=self.generator).tolist()[0]))
        self.dataset_val = Subset(dataset, sorted(indices[n_train:n_train+n_val], key=lambda x: torch.rand(1, generator=self.generator).tolist()[0]))

        # Both splits must tile evenly into batches (drop_last is a separate knob).
        assert len(self.dataset_train) % self.hparams.batch_size == 0
        assert len(self.dataset_val) % self.hparams.batch_size == 0

    def train_dataloader(self):
        """DataLoader over the training split; shuffling controlled by hparams."""
        return DataLoader(self.dataset_train,
            batch_size = self.hparams.batch_size,
            drop_last = self.hparams.drop_last,
            num_workers = self.hparams.num_workers,
            persistent_workers = self.hparams.persistent_workers,
            pin_memory = self.hparams.pin_memory,
            prefetch_factor = self.hparams.prefetch_factor,
            shuffle = self.hparams.shuffle,
            generator = self.generator,
        )

    def val_dataloader(self):
        """DataLoader over the validation split (never shuffled)."""
        return DataLoader(self.dataset_val,
            batch_size = self.hparams.batch_size,
            drop_last = self.hparams.drop_last,
            num_workers = self.hparams.num_workers,
            persistent_workers = self.hparams.persistent_workers,
            pin_memory = self.hparams.pin_memory,
            prefetch_factor = self.hparams.prefetch_factor,
            generator = self.generator,
        )
|
||||
|
||||
|
||||
class StanfordUVDataModule(RayFieldAdDataModuleBase):
    """Datamodule over UV scans of the Stanford 3D scan repository meshes."""
    # up-axis of this dataset's meshes, consumed by viewer actions
    skyward = "+Z"
    def __init__(self,
            data_dir           : Union[str, Path, None] = None,
            # NOTE(review): mutable default list — harmless here since it is
            # only read/reassigned before save_hyperparameters(), never mutated.
            obj_names          : list[str] = ["bunny"], # empty means all

            prng_seed          : int   = 1337,
            step               : int   = 2,
            batch_size         : int   = 5,
            drop_last          : bool  = False,
            num_workers        : int   = 8,
            persistent_workers : bool  = True,
            pin_memory         : bool  = True,
            prefetch_factor    : int   = 2,
            shuffle            : bool  = True,
            val_fraction       : float = 0.30,
            ):
        super().__init__()
        if not obj_names:
            obj_names = stanford_read.list_object_names()
        # snapshots all __init__ args (incl. the possibly-expanded obj_names)
        # into self.hparams
        self.save_hyperparameters()

    @property
    def observation_ids(self) -> list[str]:
        # One latent code per object name.
        return self.hparams.obj_names

    def mk_ad_dataset(self) -> common.AutodecoderDataset:
        return stanford.AutodecoderSingleViewUVScanDataset(
            obj_names = self.hparams.obj_names,
            data_path = self.hparams.data_dir,
        )

    @staticmethod
    def get_trimesh_from_uid(obj_name) -> Trimesh:
        # Imported lazily: mesh_to_sdf is only needed by viewer actions.
        import mesh_to_sdf
        mesh = stanford_read.read_mesh(obj_name)
        return mesh_to_sdf.scale_to_unit_sphere(mesh)

    @staticmethod
    def get_sphere_scan_from_uid(obj_name) -> SingleViewUVScan:
        return stanford_read.read_mesh_mesh_sphere_scan(obj_name)
|
||||
|
||||
|
||||
class CosegUVDataModule(RayFieldAdDataModuleBase):
    """Datamodule over UV scans of the COSEG shape sets."""
    # up-axis of this dataset's meshes, consumed by viewer actions
    skyward = "+Y"
    def __init__(self,
            data_dir           : Union[str, Path, None] = None,
            # accepts any sequence; normalized to a tuple below before
            # save_hyperparameters()
            object_sets        : tuple[str, ...] = ["tele-aliens"], # empty means all

            prng_seed          : int   = 1337,
            step               : int   = 2,
            batch_size         : int   = 5,
            drop_last          : bool  = False,
            num_workers        : int   = 8,
            persistent_workers : bool  = True,
            pin_memory         : bool  = True,
            prefetch_factor    : int   = 2,
            shuffle            : bool  = True,
            val_fraction       : float = 0.30,
            ):
        super().__init__()
        if not object_sets:
            object_sets = coseg_read.list_object_sets()
        object_sets = tuple(object_sets)
        self.save_hyperparameters()

    @property
    def observation_ids(self) -> list[str]:
        # One latent code per model id string across the selected sets.
        return coseg_read.list_model_id_strings(self.hparams.object_sets)

    def mk_ad_dataset(self) -> common.AutodecoderDataset:
        return coseg.AutodecoderSingleViewUVScanDataset(
            object_sets = self.hparams.object_sets,
            data_path = self.hparams.data_dir,
        )

    @staticmethod
    def get_trimesh_from_uid(string_uid):
        # No ground-truth mesh accessor for COSEG (viewer GT overlay unavailable).
        raise NotImplementedError

    @staticmethod
    def get_sphere_scan_from_uid(string_uid) -> SingleViewUVScan:
        uid = coseg_read.model_id_string_to_uid(string_uid)
        return coseg_read.read_mesh_mesh_sphere_scan(*uid)
|
||||
|
||||
|
||||
def mk_cli(args=None) -> CliInterface:
    """Build the experiment CLI: model/datamodule wiring plus extra actions.

    Registers a pre-training callback and four actions (interactive viewer,
    two headless video renderers, and metric computation) on a CliInterface
    and returns it.
    """
    cli = CliInterface(
        module_cls = IField,
        datamodule_cls = [StanfordUVDataModule, CosegUVDataModule],
        workdir = Path(__file__).parent.resolve(),
        experiment_name_prefix = "ifield",
    )
    cli.trainer_defaults.update(dict(
        precision = 16,
        min_epochs = 5,
    ))

    @cli.register_pre_training_callback
    def populate_autodecoder_z_uids(args: Namespace, config: Munch, module: IField, trainer: pl.Trainer, datamodule: RayFieldAdDataModuleBase, logger: logging.Logger):
        # Seed the auto-decoder's latent table with one entry per observation
        # before training starts.
        module.set_observation_ids(datamodule.observation_ids)
        rank = getattr(rank_zero_only, "rank", 0)
        rich.print(f"[rank {rank}] {len(datamodule.observation_ids) = }")
        rich.print(f"[rank {rank}] {len(datamodule.observation_ids) > 1 = }")
        rich.print(f"[rank {rank}] {module.is_conditioned = }")

    @cli.register_action(help="Interactive window with direct renderings from the model", args=[
        ("--shading", dict(type=int, default=ModelViewer.vizmodes_shading .index("lambertian"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_shading))}}}")),
        ("--centroid", dict(type=int, default=ModelViewer.vizmodes_centroids.index("best-centroids-colored"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_centroids))}}}")),
        ("--spheres", dict(type=int, default=ModelViewer.vizmodes_spheres .index(None), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_spheres))}}}")),
        ("--analytical-normals", dict(action="store_true")),
        ("--ground-truth", dict(action="store_true")),
        ("--solo-atom",dict(type=int, default=None, help="Rendering mode")),
        ("--res", dict(type=int, nargs=2, default=(210, 160), help="Rendering resolution")),
        ("--bg", dict(choices=["map", "white", "black"], default="map")),
        ("--skyward", dict(type=str, default="+Z", help='one of: "+X", "-X", "+Y", "-Y", ["+Z"], "-Z"')),
        ("--scale", dict(type=int, default=3, help="Rendering scale")),
        ("--fps", dict(type=int, default=None, help="FPS upper limit")),
        ("--cam-state",dict(type=str, default=None, help="json cam state, expored with CTRL+H")),
        ("--write", dict(type=Path, default=None, help="Where to write a screenshot.")),
    ])
    @torch.no_grad()
    def viewer(args: Namespace, config: Munch, model: IField):
        """Open an interactive viewer, or write a single screenshot with --write."""
        datamodule_cls: RayFieldAdDataModuleBase = cli.get_datamodule_cls_from_config(args, config)

        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            model.to("cuda")
        viewer = ModelViewer(model, start_uid=next(iter(model.keys())),
            name = config.experiment_name,
            screenshot_dir = Path(__file__).parent.parent / "images/pygame-viewer",
            res = args.res,
            skyward = args.skyward,
            scale = args.scale,
            mesh_gt_getter = datamodule_cls.get_trimesh_from_uid,
        )
        # Apply display-mode flags; indices refer to ModelViewer's vizmode lists.
        viewer.display_mode_shading = args.shading
        viewer.display_mode_centroid = args.centroid
        viewer.display_mode_spheres = args.spheres
        if args.ground_truth: viewer.display_mode_normals = viewer.vizmodes_normals.index("ground_truth")
        if args.analytical_normals: viewer.display_mode_normals = viewer.vizmodes_normals.index("analytical")
        viewer.atom_index_solo = args.solo_atom
        viewer.fps_cap = args.fps
        viewer.display_sphere_map_bg = { "map": True, "white": 255, "black": 0 }[args.bg]
        if args.cam_state is not None:
            viewer.cam_state = json.loads(args.cam_state)
        if args.write is None:
            viewer.run()
        else:
            # Headless one-frame render straight to a PNG screenshot.
            assert args.write.suffix == ".png", args.write.name
            viewer.render_headless(args.write,
                n_frames = 1,
                fps = 1,
                state_callback = None,
            )

    @cli.register_action(help="Prerender direct renderings from the model", args=[
        ("output_path",dict(type=Path, help="Where to store the output. We recommend a .mp4 suffix.")),
        ("uids", dict(type=str, nargs="*")),
        ("--frames", dict(type=int, default=60, help="Number of per interpolation. Default is 60")),
        ("--fps", dict(type=int, default=60, help="Default is 60")),
        ("--shading", dict(type=int, default=ModelViewer.vizmodes_shading .index("lambertian"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_shading))}}}")),
        ("--centroid", dict(type=int, default=ModelViewer.vizmodes_centroids.index("best-centroids-colored"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_centroids))}}}")),
        ("--spheres", dict(type=int, default=ModelViewer.vizmodes_spheres .index(None), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_spheres))}}}")),
        ("--analytical-normals", dict(action="store_true")),
        ("--solo-atom",dict(type=int, default=None, help="Rendering mode")),
        ("--res", dict(type=int, nargs=2, default=(240, 240), help="Rendering resolution. Default is 240 240")),
        ("--bg", dict(choices=["map", "white", "black"], default="map")),
        ("--skyward", dict(type=str, default="+Z", help='one of: "+X", "-X", "+Y", "-Y", ["+Z"], "-Z"')),
        ("--bitrate", dict(type=str, default="1500k", help="Encoding bitrate. Default is 1500k")),
        ("--cam-state",dict(type=str, default=None, help="json cam state, expored with CTRL+H")),
    ])
    @torch.no_grad()
    def render_video_interpolation(args: Namespace, config: Munch, model: IField, **kw):
        """Render a video interpolating the latent code between object uids."""
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            model.to("cuda")
        uids = args.uids or list(model.keys())
        assert len(uids) > 1
        # When rendering all uids, loop back to the first for a seamless cycle.
        if not args.uids: uids.append(uids[0])
        viewer = ModelViewer(model, uids[0],
            name = config.experiment_name,
            screenshot_dir = Path(__file__).parent.parent / "images/pygame-viewer",
            res = args.res,
            skyward = args.skyward,
        )
        if args.cam_state is not None:
            viewer.cam_state = json.loads(args.cam_state)
        viewer.display_mode_shading = args.shading
        viewer.display_mode_centroid = args.centroid
        viewer.display_mode_spheres = args.spheres
        if args.analytical_normals: viewer.display_mode_normals = viewer.vizmodes_normals.index("analytical")
        viewer.atom_index_solo = args.solo_atom
        viewer.display_sphere_map_bg = { "map": True, "white": 255, "black": 0 }[args.bg]
        def state_callback(self: ModelViewer, frame: int):
            # Tint non-boundary frames; on each segment boundary flash white
            # and advance the interpolation target to the next uid.
            # NOTE(review): nesting reconstructed from a flattened paste —
            # confirm the else-branch extent against the original file.
            if frame % args.frames:
                self.lambertian_color = (0.8, 0.8, 1.0)
            else:
                self.lambertian_color = (1.0, 1.0, 1.0)
                self.fps = args.frames
                idx = frame // args.frames + 1
                if idx != len(uids):
                    self.current_uid = uids[idx]
        print(f"Writing video to {str(args.output_path)!r}...")
        viewer.render_headless(args.output_path,
            n_frames = args.frames * (len(uids)-1) + 1,
            fps = args.fps,
            state_callback = state_callback,
            bitrate = args.bitrate,
        )

    @cli.register_action(help="Prerender direct renderings from the model", args=[
        ("output_path",dict(type=Path, help="Where to store the output. We recommend a .mp4 suffix.")),
        ("--frames", dict(type=int, default=180, help="Number of frames. Default is 180")),
        ("--fps", dict(type=int, default=60, help="Default is 60")),
        ("--shading", dict(type=int, default=ModelViewer.vizmodes_shading .index("lambertian"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_shading))}}}")),
        ("--centroid", dict(type=int, default=ModelViewer.vizmodes_centroids.index("best-centroids-colored"), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_centroids))}}}")),
        ("--spheres", dict(type=int, default=ModelViewer.vizmodes_spheres .index(None), help=f"Rendering mode. {{{', '.join(f'{i}: {m!r}'for i, m in enumerate(ModelViewer.vizmodes_spheres))}}}")),
        ("--analytical-normals", dict(action="store_true")),
        ("--solo-atom",dict(type=int, default=None, help="Rendering mode")),
        ("--res", dict(type=int, nargs=2, default=(320, 240), help="Rendering resolution. Default is 320 240")),
        ("--bg", dict(choices=["map", "white", "black"], default="map")),
        ("--skyward", dict(type=str, default="+Z", help='one of: "+X", "-X", "+Y", "-Y", ["+Z"], "-Z"')),
        ("--bitrate", dict(type=str, default="1500k", help="Encoding bitrate. Default is 1500k")),
        ("--cam-state",dict(type=str, default=None, help="json cam state, expored with CTRL+H")),
    ])
    @torch.no_grad()
    def render_video_spin(args: Namespace, config: Munch, model: IField, **kw):
        """Render a video of the camera spinning one revolution around one object."""
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            model.to("cuda")
        viewer = ModelViewer(model, start_uid=next(iter(model.keys())),
            name = config.experiment_name,
            screenshot_dir = Path(__file__).parent.parent / "images/pygame-viewer",
            res = args.res,
            skyward = args.skyward,
        )
        if args.cam_state is not None:
            viewer.cam_state = json.loads(args.cam_state)
        viewer.display_mode_shading = args.shading
        viewer.display_mode_centroid = args.centroid
        viewer.display_mode_spheres = args.spheres
        if args.analytical_normals: viewer.display_mode_normals = viewer.vizmodes_normals.index("analytical")
        viewer.atom_index_solo = args.solo_atom
        viewer.display_sphere_map_bg = { "map": True, "white": 255, "black": 0 }[args.bg]
        cam_rot_x_init = viewer.cam_rot_x
        def state_callback(self: ModelViewer, frame: int):
            # One full revolution (2*pi, approximated as 3.14*2) over args.frames.
            self.cam_rot_x = cam_rot_x_init + 3.14 * (frame / args.frames) * 2
        print(f"Writing video to {str(args.output_path)!r}...")
        viewer.render_headless(args.output_path,
            n_frames = args.frames,
            fps = args.fps,
            state_callback = state_callback,
            bitrate = args.bitrate,
        )

    @cli.register_action(help="foo", args=[
        ("fname", dict(type=Path, help="where to write json")),
        ("-t", "--transpose", dict(action="store_true", help="transpose the output")),
        ("--single-shape", dict(action="store_true", help="break after first shape")),
        ("--batch-size", dict(type=int, default=40_000, help="tradeoff between vram usage and efficiency")),
        ("--n-cd", dict(type=int, default=30_000, help="Number of points to use when computing chamfer distance")),
        ("--filter-outliers", dict(action="store_true", help="like in PRIF")),
    ])
    @torch.enable_grad() # grads are needed for the jacobian-based normals
    def compute_scores(args: Namespace, config: Munch, model: IField, **kw):
        """Evaluate the model against ground-truth sphere scans and emit JSON metrics.

        Accumulates per-uid hit/miss confusion counts, intersection/silhouette
        MSE, normal cosine similarities, and chamfer distances, then prints a
        summary and writes all metrics to args.fname (or stdout for "-").
        """
        datamodule_cls: RayFieldAdDataModuleBase = cli.get_datamodule_cls_from_config(args, config)
        model.eval()
        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            model.to("cuda")

        def T(array: np.ndarray, **kw) -> torch.Tensor:
            # Move a numpy array onto the model's device; floats adopt the
            # model dtype, and tensors pass through unchanged.
            if isinstance(array, torch.Tensor): return array
            return torch.tensor(array, device=model.device, dtype=model.dtype if isinstance(array, np.floating) else None, **kw)

        # The model's output head determines which metrics apply.
        MEDIAL = model.hparams.output_mode == "medial_sphere"
        if not MEDIAL: assert model.hparams.output_mode == "orthogonal_plane"


        uids = sorted(model.keys())
        if args.single_shape: uids = [uids[0]]
        rich.print(f"{datamodule_cls.__name__ = }")
        rich.print(f"{len(uids) = }")

        # accumulators for IoU and F-Score, CD and COS

        # sum reduction:
        n = defaultdict(int)
        n_gt_hits = defaultdict(int)
        n_gt_miss = defaultdict(int)
        n_gt_missing = defaultdict(int)
        n_outliers = defaultdict(int)
        p_mse = defaultdict(int)
        s_mse = defaultdict(int)
        cossim_med = defaultdict(int) # medial normals
        cossim_jac = defaultdict(int) # jacobian normals
        TP,FN,FP,TN = [defaultdict(int) for _ in range(4)] # IoU and f-score
        # mean reduction:
        cd_dist = {} # chamfer distance
        cd_cos_med = {} # chamfer medial normals
        cd_cos_jac = {} # chamfer jacobian normals
        all_metrics = dict(
            n=n, n_gt_hits=n_gt_hits, n_gt_miss=n_gt_miss, n_gt_missing=n_gt_missing, p_mse=p_mse,
            cossim_jac=cossim_jac,
            TP=TP, FN=FN, FP=FP, TN=TN, cd_dist=cd_dist,
            cd_cos_jac=cd_cos_jac,
        )
        if MEDIAL:
            all_metrics["s_mse"] = s_mse
            all_metrics["cossim_med"] = cossim_med
            all_metrics["cd_cos_med"] = cd_cos_med
        if args.filter_outliers:
            all_metrics["n_outliers"] = n_outliers

        t = datetime.now()
        for uid in tqdm(uids, desc="Dataset", position=0, leave=True, disable=len(uids)<=1):
            sphere_scan_gt = datamodule_cls.get_sphere_scan_from_uid(uid)

            z = model[uid].detach()

            all_intersections = []
            all_medial_normals = []
            all_jacobian_normals = []

            # Evaluate the scan in chunks of --batch-size rays to bound VRAM use.
            step = args.batch_size
            for i in tqdm(range(0, sphere_scan_gt.hits.shape[0], step), desc=f"Item {uid!r}", position=1, leave=False):
                # prepare batch and gt
                # origins require grad: the jacobian of intersections w.r.t.
                # origins yields the analytical normals below.
                origins = T(sphere_scan_gt.cam_pos [i:i+step, :], requires_grad = True)
                dirs = T(sphere_scan_gt.ray_dirs [i:i+step, :])
                gt_hits = T(sphere_scan_gt.hits [i:i+step])
                gt_miss = T(sphere_scan_gt.miss [i:i+step])
                gt_missing = T(sphere_scan_gt.missing [i:i+step])
                gt_points = T(sphere_scan_gt.points [i:i+step, :])
                gt_normals = T(sphere_scan_gt.normals [i:i+step, :])
                gt_distances = T(sphere_scan_gt.distances[i:i+step])

                # forward
                if MEDIAL:
                    (
                        depths,
                        silhouettes,
                        intersections,
                        medial_normals,
                        is_intersecting,
                        sphere_centers,
                        sphere_radii,
                    ) = model({
                        "origins" : origins,
                        "dirs" : dirs,
                    }, z, intersections_only=False, allow_nans=False)
                else:
                    silhouettes = medial_normals = None
                    intersections, is_intersecting = model({
                        "origins" : origins,
                        "dirs" : dirs,
                    }, z, normalize_origins = True)
                    # orthogonal_plane head emits a soft hit score; threshold it
                    is_intersecting = is_intersecting > 0.5
                jac = diff.jacobian(intersections, origins, detach=True)

                # outlier removal (PRIF)
                if args.filter_outliers:
                    outliers = jac.norm(dim=-2).norm(dim=-1) > 5
                    n_outliers[uid] += outliers[is_intersecting].sum().item()
                    # We count filtered points as misses
                    is_intersecting &= ~outliers

                model.zero_grad()
                jacobian_normals = model.compute_normals_from_intersection_origin_jacobian(jac, dirs)

                # Collect predicted surface points/normals for the chamfer pass.
                all_intersections .append(intersections .detach()[is_intersecting.detach(), :])
                all_medial_normals .append(medial_normals .detach()[is_intersecting.detach(), :]) if MEDIAL else None
                all_jacobian_normals.append(jacobian_normals.detach()[is_intersecting.detach(), :])

                # accumulate metrics
                with torch.no_grad():
                    n [uid] += dirs.shape[0]
                    n_gt_hits [uid] += gt_hits.sum().item()
                    n_gt_miss [uid] += gt_miss.sum().item()
                    n_gt_missing [uid] += gt_missing.sum().item()
                    p_mse [uid] += (gt_points [gt_hits, :] - intersections[gt_hits, :]).norm(2, dim=-1).pow(2).sum().item()
                    if MEDIAL: s_mse [uid] += (gt_distances[gt_miss] - silhouettes [gt_miss] ) .pow(2).sum().item()
                    if MEDIAL: cossim_med[uid] += (1-F.cosine_similarity(gt_normals[gt_hits, :], medial_normals [gt_hits, :], dim=-1).abs()).sum().item() # to match what pytorch3d does for CD
                    cossim_jac [uid] += (1-F.cosine_similarity(gt_normals[gt_hits, :], jacobian_normals[gt_hits, :], dim=-1).abs()).sum().item() # to match what pytorch3d does for CD
                    not_intersecting = ~is_intersecting
                    # "missing" gt rays are counted as positives for the confusion matrix
                    TP [uid] += ((gt_hits | gt_missing) & is_intersecting).sum().item() # True Positive
                    FN [uid] += ((gt_hits | gt_missing) & not_intersecting).sum().item() # False Negative
                    FP [uid] += (gt_miss & is_intersecting).sum().item() # False Positive
                    TN [uid] += (gt_miss & not_intersecting).sum().item() # True Negative

            all_intersections = torch.cat(all_intersections, dim=0)
            all_medial_normals = torch.cat(all_medial_normals, dim=0) if MEDIAL else None
            all_jacobian_normals = torch.cat(all_jacobian_normals, dim=0)

            hits = sphere_scan_gt.hits # brevity
            print()

            # Subsample --n-cd points on each side before the chamfer pass.
            assert all_intersections.shape[0] >= args.n_cd
            idx_cd_pred = torch.randperm(all_intersections.shape[0])[:args.n_cd]
            idx_cd_gt = torch.randperm(hits.sum()) [:args.n_cd]

            print("cd... ", end="")
            tt = datetime.now()
            loss_cd, loss_cos_jac = chamfer_distance(
                x = all_intersections [None, :, :][:, idx_cd_pred, :].detach(),
                x_normals = all_jacobian_normals [None, :, :][:, idx_cd_pred, :].detach(),
                y = T(sphere_scan_gt.points [None, hits, :][:, idx_cd_gt, :]),
                y_normals = T(sphere_scan_gt.normals[None, hits, :][:, idx_cd_gt, :]),
                batch_reduction = "sum", point_reduction = "sum",
            )
            if MEDIAL: _, loss_cos_med = chamfer_distance(
                x = all_intersections [None, :, :][:, idx_cd_pred, :].detach(),
                x_normals = all_medial_normals [None, :, :][:, idx_cd_pred, :].detach(),
                y = T(sphere_scan_gt.points [None, hits, :][:, idx_cd_gt, :]),
                y_normals = T(sphere_scan_gt.normals[None, hits, :][:, idx_cd_gt, :]),
                batch_reduction = "sum", point_reduction = "sum",
            )
            print(datetime.now() - tt)

            cd_dist [uid] = loss_cd.item()
            cd_cos_med [uid] = loss_cos_med.item() if MEDIAL else None
            cd_cos_jac [uid] = loss_cos_jac.item()

            print()
        model.zero_grad(set_to_none=True)
        print("Total time:", datetime.now() - t)
        print("Time per item:", (datetime.now() - t) / len(uids)) if len(uids) > 1 else None

        # NOTE(review): `sum` deliberately shadows the builtin below; the real
        # one is reached via builtins.sum. These reduce the per-uid tables.
        sum = lambda *xs: builtins .sum (itertools.chain(*(x.values() for x in xs)))
        mean = lambda *xs: statistics.mean (itertools.chain(*(x.values() for x in xs)))
        stdev = lambda *xs: statistics.stdev(itertools.chain(*(x.values() for x in xs)))
        n_cd = args.n_cd
        P = sum(TP)/(sum(TP, FP))
        R = sum(TP)/(sum(TP, FN))
        print(f"{mean(n) = :11.1f} (rays per object)")
        print(f"{mean(n_gt_hits) = :11.1f} (gt rays hitting per object)")
        print(f"{mean(n_gt_miss) = :11.1f} (gt rays missing per object)")
        print(f"{mean(n_gt_missing) = :11.1f} (gt rays unknown per object)")
        print(f"{mean(n_outliers) = :11.1f} (gt rays unknown per object)") if args.filter_outliers else None
        print(f"{n_cd = :11.0f} (cd rays per object)")
        print(f"{mean(n_gt_hits) / mean(n) = :11.8f} (fraction rays hitting per object)")
        print(f"{mean(n_gt_miss) / mean(n) = :11.8f} (fraction rays missing per object)")
        print(f"{mean(n_gt_missing)/ mean(n) = :11.8f} (fraction rays unknown per object)")
        print(f"{mean(n_outliers) / mean(n) = :11.8f} (fraction rays unknown per object)") if args.filter_outliers else None
        print(f"{sum(TP)/sum(n) = :11.8f} (total ray TP)")
        print(f"{sum(TN)/sum(n) = :11.8f} (total ray TN)")
        print(f"{sum(FP)/sum(n) = :11.8f} (total ray FP)")
        print(f"{sum(FN)/sum(n) = :11.8f} (total ray FN)")
        print(f"{sum(TP, FN, FP)/sum(n) = :11.8f} (total ray union)")
        print(f"{sum(TP)/sum(TP, FN, FP) = :11.8f} (total ray IoU)")
        print(f"{sum(TP)/(sum(TP, FP)) = :11.8f} -> P (total ray precision)")
        print(f"{sum(TP)/(sum(TP, FN)) = :11.8f} -> R (total ray recall)")
        print(f"{2*(P*R)/(P+R) = :11.8f} (total ray F-score)")
        print(f"{sum(p_mse)/sum(n_gt_hits) = :11.8f} (mean ray intersection mean squared error)")
        print(f"{sum(s_mse)/sum(n_gt_miss) = :11.8f} (mean ray silhoutette mean squared error)")
        print(f"{sum(cossim_med)/sum(n_gt_hits) = :11.8f} (mean ray medial reduced cosine similarity)") if MEDIAL else None
        print(f"{sum(cossim_jac)/sum(n_gt_hits) = :11.8f} (mean ray analytical reduced cosine similarity)")
        print(f"{mean(cd_dist) /n_cd * 1e3 = :11.8f} (mean chamfer distance)")
        print(f"{mean(cd_cos_med)/n_cd = :11.8f} (mean chamfer reduced medial cossim distance)") if MEDIAL else None
        print(f"{mean(cd_cos_jac)/n_cd = :11.8f} (mean chamfer reduced analytical cossim distance)")
        print(f"{stdev(cd_dist) /n_cd * 1e3 = :11.8f} (stdev chamfer distance)") if len(cd_dist) > 1 else None
        print(f"{stdev(cd_cos_med)/n_cd = :11.8f} (stdev chamfer reduced medial cossim distance)") if len(cd_cos_med) > 1 and MEDIAL else None
        print(f"{stdev(cd_cos_jac)/n_cd = :11.8f} (stdev chamfer reduced analytical cossim distance)") if len(cd_cos_jac) > 1 else None

        if args.transpose:
            # Re-key from metric->uid->value to uid->metric->value.
            all_metrics, old_metrics = defaultdict(dict), all_metrics
            for m, table in old_metrics.items():
                for uid, vals in table.items():
                    all_metrics[uid][m] = vals
            all_metrics["_hparams"] = dict(n_cd=args.n_cd)
        else:
            all_metrics["n_cd"] = args.n_cd

        if str(args.fname) == "-":
            # Write to stdout, one metric table per line.
            print("{", ',\n'.join(
                f" {json.dumps(k)}: {json.dumps(v)}"
                for k, v in all_metrics.items()
            ), "}", sep="\n")
        else:
            args.fname.parent.mkdir(parents=True, exist_ok=True)
            with args.fname.open("w") as f:
                json.dump(all_metrics, f, indent=2)

    return cli
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Build the CLI and dispatch to the requested action/training run.
    mk_cli().run()
|
||||
263
experiments/marf.yaml.j2
Executable file
263
experiments/marf.yaml.j2
Executable file
@@ -0,0 +1,263 @@
|
||||
#!/usr/bin/env -S python ./marf.py module
|
||||
{% do require_defined("select", select, 0, "$SLURM_ARRAY_TASK_ID") %}{# requires jinja2.ext.do #}
|
||||
{% do require_defined("mode", mode, "single", "ablation", "multi", strict=true, exchaustive=true) %}{# requires jinja2.ext.do #}
|
||||
{% set counter = itertools.count(start=0, step=1) %}
|
||||
{% set do_condition = mode == "multi" %}
|
||||
{% set do_ablation = mode == "ablation" %}
|
||||
|
||||
{% set hp_matrix = namespace() %}{# hyper parameter matrix #}
|
||||
|
||||
{% set hp_matrix.input_mode = [
|
||||
"both",
|
||||
"perp_foot",
|
||||
"plucker",
|
||||
] if do_ablation else [ "both" ] %}
|
||||
{% set hp_matrix.output_mode = ["medial_sphere", "orthogonal_plane"] %}{##}
|
||||
{% set hp_matrix.output_mode = ["medial_sphere"] %}{##}
|
||||
{% set hp_matrix.n_atoms = [16, 1, 4, 8, 32, 64] if do_ablation else [16] %}{##}
|
||||
{% set hp_matrix.normal_coeff = [0.25, 0] if do_ablation else [0.25] %}{##}
|
||||
{% set hp_matrix.dataset_item = [objname] if objname is defined else (["armadillo", "bunny", "happy_buddha", "dragon", "lucy"] if not do_condition else ["four-legged"]) %}{##}
|
||||
{% set hp_matrix.test_val_split_frac = [0.7] %}{##}
|
||||
{% set hp_matrix.lr_coeff = [5] %}{##}
|
||||
{% set hp_matrix.warmup_epochs = [1] if not do_condition else [0.1] %}{##}
|
||||
{% set hp_matrix.improve_miss_grads = [True] %}{##}
|
||||
{% set hp_matrix.normalize_ray_dirs = [True] %}{##}
|
||||
{% set hp_matrix.intersection_coeff = [2, 0] if do_ablation else [2] %}{##}
|
||||
{% set hp_matrix.miss_distance_coeff = [1, 0, 5] if do_ablation else [1] %}{##}
|
||||
{% set hp_matrix.relative_out = [False] %}{##}
|
||||
{% set hp_matrix.hidden_features = [512] %}{# like deepsdf and prif #}
|
||||
{% set hp_matrix.hidden_layers = [8] %}{# like deepsdf, nerf, prif #}
|
||||
{% set hp_matrix.nonlinearity = ["leaky_relu"] %}{##}
|
||||
{% set hp_matrix.omega = [30] %}{##}
|
||||
{% set hp_matrix.normalization = ["layernorm"] %}{##}
|
||||
{% set hp_matrix.dropout_percent = [1] %}{##}
|
||||
{% set hp_matrix.sphere_grow_reg_coeff = [500, 0, 5000] if do_ablation else [500] %}{##}
|
||||
{% set hp_matrix.geom_init = [True, False] if do_ablation else [True] %}{##}
|
||||
{% set hp_matrix.loss_inscription = [50, 0, 250] if do_ablation else [50] %}{##}
|
||||
{% set hp_matrix.atom_centroid_norm_std_reg_negexp = [0, None] if do_ablation else [0] %}{##}
|
||||
{% set hp_matrix.curvature_reg_coeff = [0.2] %}{##}
|
||||
{% set hp_matrix.multi_view_reg_coeff = [1, 2] if do_ablation else [1] %}{##}
|
||||
{% set hp_matrix.grad_reg = [ "multi_view", "nogradreg" ] if do_ablation else [ "multi_view" ] %}
|
||||
|
||||
{#% for hp in cartesian_hparams(hp_matrix) %}{##}
|
||||
{% for hp in ablation_hparams(hp_matrix, caartesian_keys=["output_mode", "dataset_item", "nonlinearity", "test_val_split_frac"]) %}
|
||||
|
||||
{% if hp.output_mode == "orthogonal_plane"%}
|
||||
{% if hp.normal_coeff == 0 %}{% set hp.normal_coeff = 0.25 %}
|
||||
{% elif hp.normal_coeff == 0.25 %}{% set hp.normal_coeff = 0 %}
|
||||
{% endif %}
|
||||
{% if hp.grad_reg == "multi_view" %}{% set hp.grad_reg = "nogradreg" %}
|
||||
{% elif hp.grad_reg == "nogradreg" %}{% set hp.grad_reg = "multi_view" %}
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
{# filter bad/uninteresting hparam combos #}
|
||||
{% if ( hp.nonlinearity != "sine" and hp.omega != 30 )
|
||||
or ( hp.nonlinearity == "sine" and hp.normalization in ("layernorm", "layernorm_na") )
|
||||
or ( hp.multi_view_reg_coeff != 1 and "multi_view" not in hp.grad_reg )
|
||||
or ( "curvature" not in hp.grad_reg and hp.curvature_reg_coeff != 0.2 )
|
||||
or ( hp.output_mode == "orthogonal_plane" and hp.input_mode != "both" )
|
||||
or ( hp.output_mode == "orthogonal_plane" and hp.atom_centroid_norm_std_reg_negexp != 0 )
|
||||
or ( hp.output_mode == "orthogonal_plane" and hp.n_atoms != 16 )
|
||||
or ( hp.output_mode == "orthogonal_plane" and hp.sphere_grow_reg_coeff != 500 )
|
||||
or ( hp.output_mode == "orthogonal_plane" and hp.loss_inscription != 50 )
|
||||
or ( hp.output_mode == "orthogonal_plane" and hp.miss_distance_coeff != 1 )
|
||||
or ( hp.output_mode == "orthogonal_plane" and hp.test_val_split_frac != 0.7 )
|
||||
or ( hp.output_mode == "orthogonal_plane" and hp.lr_coeff != 5 )
|
||||
or ( hp.output_mode == "orthogonal_plane" and not hp.geom_init )
|
||||
or ( hp.output_mode == "orthogonal_plane" and not hp.intersection_coeff )
|
||||
%}
|
||||
{% continue %}{# requires jinja2.ext.loopcontrols #}
|
||||
{% endif %}
|
||||
|
||||
{% set index = next(counter) %}
|
||||
{% if select is not defined and index > 0 %}---{% endif %}
|
||||
{% if select is not defined or int(select) == index %}
|
||||
|
||||
trainer:
|
||||
gradient_clip_val : 1.0
|
||||
max_epochs : 200
|
||||
min_epochs : 200
|
||||
log_every_n_steps : 20
|
||||
|
||||
{% if not do_condition %}
|
||||
|
||||
StanfordUVDataModule:
|
||||
obj_names : ["{{ hp.dataset_item }}"]
|
||||
step : 4
|
||||
batch_size : 8
|
||||
val_fraction : {{ 1-hp.test_val_split_frac }}
|
||||
|
||||
{% else %}{# if do_condition #}
|
||||
|
||||
CosegUVDataModule:
|
||||
object_sets : ["{{ hp.dataset_item }}"]
|
||||
step : 4
|
||||
batch_size : 8
|
||||
val_fraction : {{ 1-hp.test_val_split_frac }}
|
||||
|
||||
{% endif %}{# if do_condition #}
|
||||
|
||||
logging:
|
||||
save_dir : logdir
|
||||
type : tensorboard
|
||||
project : ifield
|
||||
|
||||
{% autoescape false %}
|
||||
{% do require_defined("experiment_name", experiment_name, "single-shape" if do_condition else "multi-shape", strict=true) %}
|
||||
{% set input_mode_abbr = hp.input_mode
|
||||
.replace("plucker", "plkr")
|
||||
.replace("perp_foot", "prpft")
|
||||
%}
|
||||
{% set output_mode_abbr = hp.output_mode
|
||||
.replace("medial_sphere", "marf")
|
||||
.replace("orthogonal_plane", "prif")
|
||||
%}
|
||||
experiment_name: experiment-{{ "" if experiment_name is not defined else experiment_name }}
|
||||
{#--#}-{{ hp.dataset_item }}
|
||||
{#--#}-{{ input_mode_abbr }}2{{ output_mode_abbr }}
|
||||
{#--#}
|
||||
{%- if hp.output_mode == "medial_sphere" -%}
|
||||
{#--#}-{{ hp.n_atoms }}atom
|
||||
{#--# }-{{ "rel" if hp.relative_out else "norel" }}
|
||||
{#--# }-{{ "e" if hp.improve_miss_grads else "0" }}sqrt
|
||||
{#--#}-{{ int(hp.loss_inscription) if hp.loss_inscription else "no" }}xinscr
|
||||
{#--#}-{{ int(hp.miss_distance_coeff * 10) }}dmiss
|
||||
{#--#}-{{ "geom" if hp.geom_init else "nogeom" }}
|
||||
{#--#}{% if "curvature" in hp.grad_reg %}
|
||||
{#- -#}-{{ int(hp.curvature_reg_coeff*10) }}crv
|
||||
{#--#}{%- endif -%}
|
||||
{%- elif hp.output_mode == "orthogonal_plane" -%}
|
||||
{#--#}
|
||||
{%- endif -%}
|
||||
{#--#}-{{ int(hp.intersection_coeff*10) }}chit
|
||||
{#--#}-{{ int(hp.normal_coeff*100) or "no" }}cnrml
|
||||
{#--# }-{{ "do" if hp.normalize_ray_dirs else "no" }}raynorm
|
||||
{#--#}-{{ hp.hidden_layers }}x{{ hp.hidden_features }}fc
|
||||
{#--#}-{{ hp.nonlinearity or "linear" }}
|
||||
{#--#}
|
||||
{%- if hp.nonlinearity == "sine" -%}
|
||||
{#--#}-{{ hp.omega }}omega
|
||||
{#--#}
|
||||
{%- endif -%}
|
||||
{%- if hp.output_mode == "medial_sphere" -%}
|
||||
{#--#}-{{ str(hp.atom_centroid_norm_std_reg_negexp).replace(*"-n") if hp.atom_centroid_norm_std_reg_negexp is not none else 'no' }}minatomstdngxp
|
||||
{#--#}-{{ hp.sphere_grow_reg_coeff }}sphgrow
|
||||
{#--#}
|
||||
{%- endif -%}
|
||||
{#--#}-{{ int(hp.dropout_percent*10) }}mdrop
|
||||
{#--#}-{{ hp.normalization or "nonorm" }}
|
||||
{#--#}-{{ hp.grad_reg }}
|
||||
{#--#}{% if "multi_view" in hp.grad_reg %}
|
||||
{#- -#}-{{ int(hp.multi_view_reg_coeff*10) }}dmv
|
||||
{#--#}{%- endif -%}
|
||||
{#--#}-{{ "concat" if do_condition else "nocond" }}
|
||||
{#--#}-{{ int(hp.warmup_epochs*100) }}cwu{{ int(hp.lr_coeff*100) }}clr{{ int(hp.test_val_split_frac*100) }}tvs
|
||||
{#--#}-{{ gen_run_uid(4) }} # select with --Oselect={{ index }}
|
||||
{#--#}
|
||||
{##}
|
||||
|
||||
{% endautoescape %}
|
||||
IntersectionFieldAutoDecoderModel:
|
||||
_extra: # used for easier introspection with jq
|
||||
dataset_item: {{ hp.dataset_item | to_json}}
|
||||
dataset_test_val_frac: {{ hp.test_val_split_frac }}
|
||||
select: {{ index }}
|
||||
|
||||
input_mode : {{ hp.input_mode }} # in {plucker, perp_foot, both}
|
||||
output_mode : {{ hp.output_mode }} # in {medial_sphere, orthogonal_plane}
|
||||
#latent_features : 256 # int
|
||||
#latent_features : 128 # int
|
||||
latent_features : 16 # int
|
||||
hidden_features : {{ hp.hidden_features }} # int
|
||||
hidden_layers : {{ hp.hidden_layers }} # int
|
||||
|
||||
improve_miss_grads : {{ bool(hp.improve_miss_grads) | to_json }}
|
||||
normalize_ray_dirs : {{ bool(hp.normalize_ray_dirs) | to_json }}
|
||||
|
||||
loss_intersection : {{ hp.intersection_coeff }}
|
||||
loss_intersection_l2 : 0
|
||||
loss_intersection_proj : 0
|
||||
loss_intersection_proj_l2 : 0
|
||||
|
||||
loss_normal_cossim : {{ hp.normal_coeff }} * EaseSin(85, 15)
|
||||
loss_normal_euclid : 0
|
||||
loss_normal_cossim_proj : 0
|
||||
loss_normal_euclid_proj : 0
|
||||
|
||||
{% if "multi_view" in hp.grad_reg %}
|
||||
loss_multi_view_reg : 0.1 * {{ hp.multi_view_reg_coeff }} * Linear(50)
|
||||
{% else %}
|
||||
loss_multi_view_reg : 0
|
||||
{% endif %}
|
||||
|
||||
{% if hp.output_mode == "orthogonal_plane" %}
|
||||
|
||||
loss_hit_cross_entropy : 1
|
||||
|
||||
{% elif hp.output_mode == "medial_sphere" %}
|
||||
|
||||
loss_hit_nodistance_l1 : 0
|
||||
loss_hit_nodistance_l2 : 100 * {{ hp.miss_distance_coeff }}
|
||||
loss_miss_distance_l1 : 0
|
||||
loss_miss_distance_l2 : 10 * {{ hp.miss_distance_coeff }}
|
||||
|
||||
loss_inscription_hits : {{ 0.4 * hp.loss_inscription }}
|
||||
loss_inscription_miss : 0
|
||||
loss_inscription_hits_l2 : 0
|
||||
loss_inscription_miss_l2 : {{ 6 * hp.loss_inscription }}
|
||||
|
||||
loss_sphere_grow_reg : 1e-6 * {{ hp.sphere_grow_reg_coeff }} # constant
|
||||
loss_atom_centroid_norm_std_reg: (0.09*(1-Linear(40)) + 0.01) * {{ 10**(-hp.atom_centroid_norm_std_reg_negexp) if hp.atom_centroid_norm_std_reg_negexp is not none else 0 }}
|
||||
|
||||
{% else %}{#endif hp.output_mode == "medial_sphere" #}
|
||||
THIS IS INVALID YAML
|
||||
{% endif %}
|
||||
|
||||
loss_embedding_norm : 0.01**2 * Linear(30, 0.1)
|
||||
|
||||
opt_learning_rate : {{ hp.lr_coeff }} * 10**(-4-0.5*EaseSin(170, 30)) # layernorm
|
||||
opt_warmup : {{ hp.warmup_epochs }}
|
||||
opt_weight_decay : 5e-6 # float
|
||||
|
||||
{% if hp.output_mode == "medial_sphere" %}
|
||||
|
||||
# MedialAtomNet:
|
||||
n_atoms : {{ hp.n_atoms }} # int
|
||||
{% if hp.geom_init %}
|
||||
final_init_wrr: [0.05, 0.6, 0.1]
|
||||
{% else %}
|
||||
final_init_wrr: null
|
||||
{% endif %}
|
||||
|
||||
{% endif %}
|
||||
|
||||
|
||||
# FCBlock:
|
||||
normalization : {{ hp.normalization or "null" }} # in {null, layernorm, layernorm_na, weightnorm}
|
||||
nonlinearity : {{ hp.nonlinearity or "null" }} # in {null, relu, leaky_relu, silu, softplus, elu, selu, sine, sigmoid, tanh }
|
||||
{% set middle = 1 + hp.hidden_layers // 2 + (hp.hidden_layers % 2) %}{##}
|
||||
concat_skipped_layers : [{{ middle }}, -1]
|
||||
{% if do_condition %}
|
||||
concat_conditioned_layers : [0, {{ middle }}]
|
||||
{% else %}
|
||||
concat_conditioned_layers : []
|
||||
{% endif %}
|
||||
|
||||
# FCLayer:
|
||||
negative_slope : 0.01 # float
|
||||
omega_0 : {{ hp.omega }} # float
|
||||
residual_mode : null # in {null, identity}
|
||||
|
||||
{% endif %}{# -Oselect #}
|
||||
|
||||
|
||||
{% endfor %}
|
||||
|
||||
|
||||
{% set index = next(counter) %}
|
||||
# number of possible -Oselect: {{ index }}, from 0 to {{ index-1 }}
|
||||
# local: for select in {0..{{ index-1 }}}; do python ... -Omode={{ mode }} -Oselect=$select ... ; done
|
||||
# local: for select in {0..{{ index-1 }}}; do python -O {{ argv[0] }} model marf.yaml.j2 -Omode={{ mode }} -Oselect=$select -Oexperiment_name='{{ experiment_name }}' fit --accelerator gpu ; done
|
||||
# slurm: sbatch --array=0-{{ index-1 }} runcommand.slurm python ... -Omode={{ mode }} -Oselect=\$SLURM_ARRAY_TASK_ID ...
|
||||
# slurm: sbatch --array=0-{{ index-1 }} runcommand.slurm python -O {{ argv[0] }} model marf.yaml.j2 -Omode={{ mode }} -Oselect=\$SLURM_ARRAY_TASK_ID -Oexperiment_name='{{ experiment_name }}' fit --accelerator gpu --devices -1 --strategy ddp
|
||||
849
experiments/summary.py
Executable file
849
experiments/summary.py
Executable file
@@ -0,0 +1,849 @@
|
||||
#!/usr/bin/env python
|
||||
from concurrent.futures import ThreadPoolExecutor, Future, ProcessPoolExecutor
|
||||
from functools import partial
|
||||
from more_itertools import first, last, tail
|
||||
from munch import Munch, DefaultMunch, munchify, unmunchify
|
||||
from pathlib import Path
|
||||
from statistics import mean, StatisticsError
|
||||
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
||||
from typing import Iterable, Optional, Literal
|
||||
from math import isnan
|
||||
import json
|
||||
import stat
|
||||
import matplotlib
|
||||
import matplotlib.colors as mcolors
|
||||
import matplotlib.pyplot as plt
|
||||
import os, os.path
|
||||
import re
|
||||
import shlex
|
||||
import time
|
||||
import itertools
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import traceback
|
||||
import typer
|
||||
import warnings
|
||||
import yaml
|
||||
import tempfile
|
||||
|
||||
# Directory layout anchors for the experiment logs.
# NOTE: __file__ is this script itself, so the experiments *directory* is its
# parent -- without .parent every derived path would nest under a file path
# (e.g. ".../summary.py/logdir"), which can never exist.
EXPERIMENTS = Path(__file__).resolve().parent
LOGDIR = EXPERIMENTS / "logdir"
TENSORBOARD = LOGDIR / "tensorboard"
SLURM_LOGS = LOGDIR / "slurm_logs"
CACHED_SUMMARIES = LOGDIR / "cached_summaries"
COMPUTED_SCORES = LOGDIR / "computed_scores"

# Sentinel marking "key absent from a row" -- distinct from a stored None.
MISSING = object()
|
||||
|
||||
class SafeLoaderIgnoreUnknown(yaml.SafeLoader):
    """A SafeLoader variant that turns unknown YAML tags into None instead of raising."""
    def ignore_unknown(self, node):
        # Fallback constructor: swallow the tagged node entirely.
        return None

# Register the fallback for every tag that has no explicit constructor
# (passing None as the tag installs a catch-all).
SafeLoaderIgnoreUnknown.add_constructor(None, SafeLoaderIgnoreUnknown.ignore_unknown)
|
||||
|
||||
def camel_to_snake_case(text: str, sep: str = "_", join_abbreviations: bool = False) -> str:
    """Convert CamelCase *text* to snake_case, splitting before each capital.

    With join_abbreviations=True, neighbouring single-letter fragments in the
    *original* split (e.g. the "U", "V" of "UVScan") are glued back together
    ("uv_scan"); this is lossy and cannot be reversed.
    """
    fragments = [chunk.lower() for chunk in re.split(r'(?=[A-Z])', text) if chunk]
    if join_abbreviations and len(fragments) > 1:
        # Decide merges from a snapshot of the fragments *before* any pops,
        # walking right-to-left so earlier indices stay valid.
        snapshot = list(fragments)
        for idx in reversed(range(len(snapshot) - 1)):
            if len(snapshot[idx]) == 1 and len(snapshot[idx + 1]) == 1:
                fragments[idx] += fragments.pop(idx + 1)
    return sep.join(fragments)
|
||||
|
||||
def flatten_dict(data: dict, key_mapper: callable = lambda x: x) -> dict:
    """Flatten one level of nesting: {"a": {"b": 1}} -> {"a/b": 1}.

    Parent keys pass through *key_mapper* before being joined with "/".
    Only one level is flattened; deeper dicts are kept as values.  If no
    value is a dict at all, *data* itself is returned untouched.
    """
    if not any(isinstance(value, dict) for value in data.values()):
        return data
    # Non-dict entries first (original order), then the flattened children
    # grouped per parent -- matching dict-union semantics.
    flat = {key: value for key, value in data.items() if not isinstance(value, dict)}
    for parent, nested in data.items():
        if isinstance(nested, dict):
            for child_key, child_val in nested.items():
                flat[f"{key_mapper(parent)}/{child_key}"] = child_val
    return flat
|
||||
|
||||
def parse_jsonl(data: str) -> Iterable[dict]:
    """Lazily parse a JSON-lines string, skipping blank/whitespace-only lines."""
    for line in data.splitlines():
        if line.strip():
            yield json.loads(line)
|
||||
|
||||
def read_jsonl(path: Path) -> Iterable[dict]:
    """Read *path* fully (closing it immediately), then yield one parsed object per non-blank line."""
    with path.open("r") as handle:
        contents = handle.read()
    yield from parse_jsonl(contents)
|
||||
|
||||
def get_experiment_paths(filter: str | None, assert_dumped = False) -> Iterable[Path]:
    """Yield tensorboard run directories whose name matches *filter* (regex search).

    Runs lacking an hparams.yaml or any tfevents file are skipped with a
    warning.  With assert_dumped (and __debug__ on), additionally assert that
    the scalar dumps produced by tf_dump are present.
    """
    for candidate in TENSORBOARD.iterdir():
        # filter is applied to the name before the cheaper is_dir check,
        # preserving the original short-circuit order
        if filter is not None and not re.search(filter, candidate.name):
            continue
        if not candidate.is_dir():
            continue

        if not (candidate / "hparams.yaml").is_file():
            warnings.warn(f"Missing hparams: {candidate}")
            continue
        if not any(candidate.glob("events.out.tfevents.*")):
            warnings.warn(f"Missing tfevents: {candidate}")
            continue

        if __debug__ and assert_dumped:
            scalars = candidate / "scalars"
            assert (scalars / "epoch.json").is_file(), candidate
            assert (scalars / "IntersectionFieldAutoDecoderModel.validation_step/loss.json").is_file(), candidate
            assert (scalars / "IntersectionFieldAutoDecoderModel.training_step/loss.json").is_file(), candidate

        yield candidate
|
||||
|
||||
def dump_pl_tensorboard_hparams(experiment: Path):
    """Re-export pieces of a run's hparams.yaml as standalone files.

    Writes config.yaml (the raw training yaml, shebang and argv header
    preserved), environ.yaml (the recorded host environment) and repo.patch
    (the recorded vcs diff) into *experiment*, logging each to stderr.
    """
    with (experiment / "hparams.yaml").open() as handle:
        hparams = yaml.load(handle, Loader=SafeLoaderIgnoreUnknown)

    # config.yaml: the raw yaml the run was launched with
    shebang = None
    config_path = experiment / "config.yaml"
    with config_path.open("w") as handle:
        pickled = hparams.get('_pickled_cli_args', {})
        raw_yaml = pickled.get('_raw_yaml', "").replace("\n\r", "\n")
        if raw_yaml.startswith("#!"):  # preserve shebang
            shebang, _, raw_yaml = raw_yaml.partition("\n")
            handle.write(f"{shebang}\n")
        argv = pickled.get('sys_argv', ['None'])
        handle.write(f"# {' '.join(map(shlex.quote, argv))}\n\n")
        handle.write(raw_yaml)
    if shebang is not None:
        # the original yaml had a shebang, so make the dump executable too
        os.chmod(config_path, config_path.stat().st_mode | stat.S_IXUSR)
    print(config_path, "written!", file=sys.stderr)

    with (experiment / "environ.yaml").open("w") as handle:
        yaml.safe_dump(pickled.get('host', {}).get('environ'), handle)
    print(experiment / "environ.yaml", "written!", file=sys.stderr)

    with (experiment / "repo.patch").open("w") as handle:
        handle.write(pickled.get('host', {}).get('vcs', "None"))
    print(experiment / "repo.patch", "written!", file=sys.stderr)
|
||||
|
||||
def dump_simple_tf_events_to_jsonl(output_dir: Path, *tf_files: Path):
    """Dump every scalar ("simpleValue") summary in *tf_files* to JSON-lines.

    One file per tag is created under *output_dir* (slashes in the tag become
    subdirectories); each line is {"step", "value", "wall_time"}.  Output
    handles stay open for the whole pass and are all closed in `finally`.
    """
    from google.protobuf.json_format import MessageToDict
    # Import the submodule we actually use.  The original imported
    # event_accumulator and relied on it transitively pulling in
    # event_file_loader -- importing it directly is robust against
    # tensorboard refactors.
    import tensorboard.backend.event_processing.event_file_loader as event_file_loader

    no_summary, no_values = {}, []  # reused fallbacks for absent message keys

    #resource.setrlimit(resource.RLIMIT_NOFILE, (2**16,-1))
    file_handles = {}
    try:
        for tffile in tf_files:
            loader = event_file_loader.LegacyEventFileLoader(str(tffile))
            for event in loader.Load():
                for summary in MessageToDict(event).get("summary", no_summary).get("value", no_values):
                    if "simpleValue" not in summary:
                        continue
                    tag = summary["tag"]
                    if tag not in file_handles:
                        fname = output_dir / f"{tag}.json"
                        print(f"Opening {str(fname)!r}...", file=sys.stderr)
                        fname.parent.mkdir(parents=True, exist_ok=True)
                        file_handles[tag] = fname.open("w") # ("a")
                    val = summary["simpleValue"]
                    # protobuf may render floats as strings; normalize here
                    data = json.dumps({
                        "step"      : event.step,
                        "value"     : float(val) if isinstance(val, str) else val,
                        "wall_time" : event.wall_time,
                    })
                    file_handles[tag].write(f"{data}\n")
    finally:
        if file_handles:
            print("Closing json files...", file=sys.stderr)
            for handle in file_handles.values():
                handle.close()
|
||||
|
||||
|
||||
# Columns that filter_jsonl_columns always keeps, even when they are
# constant/empty across every row.
# (Fix: "_val_uloss_intersection" was listed twice; set literals dedupe
# silently, so the duplicate was dead text.)
NO_FILTER = {
    "__uid",
    "_minutes",
    "_epochs",
    "_hp_nonlinearity",
    "_val_uloss_intersection",
    "_val_uloss_normal_cossim",
}
|
||||
def filter_jsonl_columns(data: Iterable[dict | None], no_filter=NO_FILTER) -> list[dict]:
    """Clean summary rows for tabulation.

    Pipeline: drop None rows; flatten one level of nesting (parent keys
    snake_cased); fold hp_omega_0 into the "sine" nonlinearity label; drop
    columns that are constant or always falsy/zero across rows (except those
    in *no_filter*); finally widen each column to a single common type.
    Returns a lazy iterator of row dicts.
    """

    def fold_omega(row: dict) -> dict:
        # Append the omega value to the nonlinearity label when it is "sine",
        # and drop the now-redundant hp_omega_0 column.
        out = {}
        for key, val in row.items():
            if key == "hp_omega_0":
                continue
            if (key.removeprefix("_"), val) == ("hp_nonlinearity", "sine"):
                val = f"{val}-{row.get('hp_omega_0', 'ERROR')}"
            out[key] = val
        return out

    def drop_constant_columns(rows: list[dict]):
        seen = {}  # column -> set of observed repr()s (plus MISSING if absent somewhere)

        keep = set()
        for row in rows:
            for key, val in row.items():
                seen.setdefault(key, set()).add(repr(val))
                # only truthy values that aren't stringified zero/None make
                # a column interesting (falsy values never whitelist a key)
                if val and val not in ("None", "0", "0.0"):
                    keep.add(key)
        for key in seen:
            if any(key not in row for row in rows):
                seen[key].add(MISSING)
        for key, reprs in seen.items():
            # a whitelisted column with a single distinct value is still boring
            if key in keep and len(reprs) == 1:
                keep.discard(key)

        keep.update(no_filter)

        for row in rows:
            yield {key: val for key, val in row.items() if key in keep}

    def widen_types(rows: list[dict]):
        # Per column, pick the "widest" type observed (earlier in the ladder
        # wins) and coerce every non-None value to it on the way out.
        ladder = (str, float, int, bool, tuple, type(None))
        column_rank = {}
        for row in rows:
            for key, val in row.items():
                if isinstance(val, list):
                    val = tuple(val)
                assert type(val) in ladder, (type(val), val)
                rank = ladder.index(type(val))
                column_rank[key] = min(column_rank.get(key, 999), rank)

        for row in rows:
            yield {
                key: ladder[column_rank[key]](val) if val is not None else None
                for key, val in row.items()
            }

    rows = (row for row in data if row is not None)
    rows = map(partial(flatten_dict, key_mapper=camel_to_snake_case), rows)
    rows = map(fold_omega, rows)
    rows = drop_constant_columns(list(rows))
    rows = widen_types(list(rows))

    return rows
|
||||
|
||||
PlotMode = Literal["stackplot", "lineplot"]
|
||||
|
||||
def plot_losses(experiments: list[Path], mode: PlotMode, write: bool = False, dump: bool = False, training: bool = False, unscaled: bool = False, force=True):
|
||||
def get_losses(experiment: Path, training: bool = True, unscaled: bool = False) -> Iterable[Path]:
|
||||
if not training and unscaled:
|
||||
return experiment.glob("scalars/*.validation_step/unscaled_loss_*.json")
|
||||
elif not training and not unscaled:
|
||||
return experiment.glob("scalars/*.validation_step/loss_*.json")
|
||||
elif training and unscaled:
|
||||
return experiment.glob("scalars/*.training_step/unscaled_loss_*.json")
|
||||
elif training and not unscaled:
|
||||
return experiment.glob("scalars/*.training_step/loss_*.json")
|
||||
|
||||
print("Mapping colors...")
|
||||
configurations = [
|
||||
dict(unscaled=unscaled, training=training),
|
||||
] if not write else [
|
||||
dict(unscaled=False, training=False),
|
||||
dict(unscaled=False, training=True),
|
||||
dict(unscaled=True, training=False),
|
||||
dict(unscaled=True, training=True),
|
||||
]
|
||||
legends = set(
|
||||
f"""{
|
||||
loss.parent.name.split(".", 1)[0]
|
||||
}.{
|
||||
loss.name.removesuffix(loss.suffix).removeprefix("unscaled_")
|
||||
}"""
|
||||
for experiment in experiments
|
||||
for kw in configurations
|
||||
for loss in get_losses(experiment, **kw)
|
||||
)
|
||||
colormap = dict(zip(
|
||||
sorted(legends),
|
||||
itertools.cycle(mcolors.TABLEAU_COLORS),
|
||||
))
|
||||
|
||||
def mkplot(experiment: Path, training: bool = True, unscaled: bool = False) -> tuple[bool, str]:
|
||||
label = f"{'unscaled' if unscaled else 'scaled'} {'training' if training else 'validation'}"
|
||||
if write:
|
||||
old_savefig_fname = experiment / f"{label.replace(' ', '-')}-{mode}.png"
|
||||
savefig_fname = experiment / "plots" / f"{label.replace(' ', '-')}-{mode}.png"
|
||||
savefig_fname.parent.mkdir(exist_ok=True, parents=True)
|
||||
if old_savefig_fname.is_file():
|
||||
old_savefig_fname.rename(savefig_fname)
|
||||
if savefig_fname.is_file() and not force:
|
||||
return True, "savefig_fname already exists"
|
||||
|
||||
# Get and sort data
|
||||
losses = {}
|
||||
for loss in get_losses(experiment, training=training, unscaled=unscaled):
|
||||
model = loss.parent.name.split(".", 1)[0]
|
||||
name = loss.name.removesuffix(loss.suffix).removeprefix("unscaled_")
|
||||
losses[f"{model}.{name}"] = (loss, list(read_jsonl(loss)))
|
||||
losses = dict(sorted(losses.items())) # sort keys
|
||||
if not losses:
|
||||
return True, "no losses"
|
||||
|
||||
# unwrap
|
||||
steps = [i["step"] for i in first(losses.values())[1]]
|
||||
values = [
|
||||
[i["value"] if not isnan(i["value"]) else 0 for i in data]
|
||||
for name, (scalar, data) in losses.items()
|
||||
]
|
||||
|
||||
# normalize
|
||||
if mode == "stackplot":
|
||||
totals = list(map(sum, zip(*values)))
|
||||
values = [
|
||||
[i / t for i, t in zip(data, totals)]
|
||||
for data in values
|
||||
]
|
||||
|
||||
print(experiment.name, label)
|
||||
fig, ax = plt.subplots(figsize=(16, 12))
|
||||
|
||||
if mode == "stackplot":
|
||||
ax.stackplot(steps, values,
|
||||
colors = list(map(colormap.__getitem__, losses.keys())),
|
||||
labels = list(
|
||||
label.split(".", 1)[1].removeprefix("loss_")
|
||||
for label in losses.keys()
|
||||
),
|
||||
)
|
||||
ax.set_xlim(0, steps[-1])
|
||||
ax.set_ylim(0, 1)
|
||||
ax.invert_yaxis()
|
||||
|
||||
elif mode == "lineplot":
|
||||
for data, color, label in zip(
|
||||
values,
|
||||
map(colormap.__getitem__, losses.keys()),
|
||||
list(losses.keys()),
|
||||
):
|
||||
ax.plot(steps, data,
|
||||
color = color,
|
||||
label = label,
|
||||
)
|
||||
ax.set_xlim(0, steps[-1])
|
||||
|
||||
else:
|
||||
raise ValueError(f"{mode=}")
|
||||
|
||||
ax.legend()
|
||||
ax.set_title(f"{label} loss\n{experiment.name}")
|
||||
ax.set_xlabel("Step")
|
||||
ax.set_ylabel("loss%")
|
||||
|
||||
if mode == "stackplot":
|
||||
ax2 = make_axes_locatable(ax).append_axes("bottom", 0.8, pad=0.05, sharex=ax)
|
||||
ax2.stackplot( steps, totals )
|
||||
|
||||
for tl in ax.get_xticklabels(): tl.set_visible(False)
|
||||
|
||||
fig.tight_layout()
|
||||
|
||||
if write:
|
||||
fig.savefig(savefig_fname, dpi=300)
|
||||
print(savefig_fname)
|
||||
plt.close(fig)
|
||||
|
||||
return False, None
|
||||
|
||||
print("Plotting...")
|
||||
if write:
|
||||
matplotlib.use('agg') # fixes "WARNING: QApplication was not created in the main() thread."
|
||||
any_error = False
|
||||
if write:
|
||||
with ThreadPoolExecutor(max_workers=None) as pool:
|
||||
futures = [
|
||||
(experiment, pool.submit(mkplot, experiment, **kw))
|
||||
for experiment in experiments
|
||||
for kw in configurations
|
||||
]
|
||||
else:
|
||||
def mkfuture(item):
|
||||
f = Future()
|
||||
f.set_result(item)
|
||||
return f
|
||||
futures = [
|
||||
(experiment, mkfuture(mkplot(experiment, **kw)))
|
||||
for experiment in experiments
|
||||
for kw in configurations
|
||||
]
|
||||
|
||||
for experiment, future in futures:
|
||||
try:
|
||||
err, msg = future.result()
|
||||
except Exception:
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
any_error = True
|
||||
continue
|
||||
if err:
|
||||
print(f"{msg}: {experiment.name}")
|
||||
any_error = True
|
||||
continue
|
||||
|
||||
if not any_error and not write: # show in main thread
|
||||
plt.show()
|
||||
elif not write:
|
||||
print("There were errors, will not show figure...", file=sys.stderr)
|
||||
|
||||
|
||||
|
||||
# =========
|
||||
|
||||
app = typer.Typer(no_args_is_help=True, add_completion=False)
|
||||
|
||||
@app.command(help="Dump simple tensorboard events to json and extract some pytorch lightning hparams")
|
||||
def tf_dump(tfevent_files: list[Path], j: int = typer.Option(1, "-j"), force: bool = False):
|
||||
# expand to all tfevents files (there may be more than one)
|
||||
tfevent_files = sorted(set([
|
||||
tffile
|
||||
for tffile in tfevent_files
|
||||
if tffile.name.startswith("events.out.tfevents.")
|
||||
] + [
|
||||
tffile
|
||||
for experiment_dir in tfevent_files
|
||||
if experiment_dir.is_dir()
|
||||
for tffile in experiment_dir.glob("events.out.tfevents.*")
|
||||
] + [
|
||||
tffile
|
||||
for hparam_file in tfevent_files
|
||||
if hparam_file.name in ("hparams.yaml", "config.yaml")
|
||||
for tffile in hparam_file.parent.glob("events.out.tfevents.*")
|
||||
]))
|
||||
|
||||
# filter already dumped
|
||||
if not force:
|
||||
tfevent_files = [
|
||||
tffile
|
||||
for tffile in tfevent_files
|
||||
if not (
|
||||
(tffile.parent / "scalars/epoch.json").is_file()
|
||||
and
|
||||
tffile.stat().st_mtime < (tffile.parent / "scalars/epoch.json").stat().st_mtime
|
||||
)
|
||||
]
|
||||
|
||||
if not tfevent_files:
|
||||
raise typer.BadParameter("Nothing to be done, consider --force")
|
||||
|
||||
jobs = {}
|
||||
for tffile in tfevent_files:
|
||||
if not tffile.is_file():
|
||||
print("ERROR: file not found:", tffile, file=sys.stderr)
|
||||
continue
|
||||
output_dir = tffile.parent / "scalars"
|
||||
jobs.setdefault(output_dir, []).append(tffile)
|
||||
with ProcessPoolExecutor() as p:
|
||||
for experiment in set(tffile.parent for tffile in tfevent_files):
|
||||
p.submit(dump_pl_tensorboard_hparams, experiment)
|
||||
for output_dir, tffiles in jobs.items():
|
||||
p.submit(dump_simple_tf_events_to_jsonl, output_dir, *tffiles)
|
||||
|
||||
@app.command(help="Propose experiment regexes")
|
||||
def propose(cmd: str = typer.Argument("summary"), null: bool = False):
|
||||
def get():
|
||||
for i in TENSORBOARD.iterdir():
|
||||
if not i.is_dir(): continue
|
||||
if not (i / "hparams.yaml").is_file(): continue
|
||||
prefix, name, *hparams, year, month, day, hhmm, uid = i.name.split("-")
|
||||
yield f"{name}.*-{year}-{month}-{day}"
|
||||
proposals = sorted(set(get()), key=lambda x: x.split(".*-", 1)[1])
|
||||
print("\n".join(
|
||||
f"{'>/dev/null ' if null else ''}{sys.argv[0]} {cmd or 'summary'} {shlex.quote(i)}"
|
||||
for i in proposals
|
||||
))
|
||||
|
||||
@app.command("list", help="List used experiment regexes")
|
||||
def list_cached_summaries(cmd: str = typer.Argument("summary")):
|
||||
if not CACHED_SUMMARIES.is_dir():
|
||||
cached = []
|
||||
else:
|
||||
cached = [
|
||||
i.name.removesuffix(".jsonl")
|
||||
for i in CACHED_SUMMARIES.iterdir()
|
||||
if i.suffix == ".jsonl"
|
||||
if i.is_file() and i.stat().st_size
|
||||
]
|
||||
def order(key: str) -> list[str]:
|
||||
return re.sub(r'[^0-9\-]', '', key.split(".*")[-1]).strip("-").split("-") + [key]
|
||||
|
||||
print("\n".join(
|
||||
f"{sys.argv[0]} {cmd or 'summary'} {shlex.quote(i)}"
|
||||
for i in sorted(cached, key=order)
|
||||
))
|
||||
|
||||
@app.command(help="Precompute the summary of a experiment regex")
|
||||
def compute_summary(filter: str, force: bool = False, dump: bool = False, no_cache: bool = False):
|
||||
cache = CACHED_SUMMARIES / f"{filter}.jsonl"
|
||||
if cache.is_file() and cache.stat().st_size:
|
||||
if not force:
|
||||
raise FileExistsError(cache)
|
||||
|
||||
def mk_summary(path: Path) -> dict | None:
|
||||
cache = path / "train_summary.json"
|
||||
if cache.is_file() and cache.stat().st_size and cache.stat().st_mtime > (path/"scalars/epoch.json").stat().st_mtime:
|
||||
with cache.open() as f:
|
||||
return json.load(f)
|
||||
else:
|
||||
with (path / "hparams.yaml").open() as f:
|
||||
hparams = munchify(yaml.load(f, Loader=SafeLoaderIgnoreUnknown), factory=partial(DefaultMunch, None))
|
||||
config = hparams._pickled_cli_args._raw_yaml
|
||||
config = munchify(yaml.load(config, Loader=SafeLoaderIgnoreUnknown), factory=partial(DefaultMunch, None))
|
||||
|
||||
try:
|
||||
train_loss = list(read_jsonl(path / "scalars/IntersectionFieldAutoDecoderModel.training_step/loss.json"))
|
||||
val_loss = list(read_jsonl(path / "scalars/IntersectionFieldAutoDecoderModel.validation_step/loss.json"))
|
||||
except:
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
return None
|
||||
|
||||
out = Munch()
|
||||
out.uid = path.name.rsplit("-", 1)[-1]
|
||||
out.name = path.name
|
||||
out.date = "-".join(path.name.split("-")[-5:-1])
|
||||
out.epochs = int(last(read_jsonl(path / "scalars/epoch.json"))["value"])
|
||||
out.steps = val_loss[-1]["step"]
|
||||
out.gpu = hparams._pickled_cli_args.host.gpus[1][1]
|
||||
|
||||
if val_loss[-1]["wall_time"] - val_loss[0]["wall_time"] > 0:
|
||||
out.batches_per_second = val_loss[-1]["step"] / (val_loss[-1]["wall_time"] - val_loss[0]["wall_time"])
|
||||
else:
|
||||
out.batches_per_second = 0
|
||||
|
||||
out.minutes = (val_loss[-1]["wall_time"] - train_loss[0]["wall_time"]) / 60
|
||||
|
||||
if (path / "scalars/PsutilMonitor/gpu.00.memory.used.json").is_file():
|
||||
max(i["value"] for i in read_jsonl(path / "scalars/PsutilMonitor/gpu.00.memory.used.json"))
|
||||
|
||||
for metric_path in (path / "scalars/IntersectionFieldAutoDecoderModel.validation_step").glob("*.json"):
|
||||
if not metric_path.is_file() or not metric_path.stat().st_size: continue
|
||||
|
||||
metric_name = metric_path.name.removesuffix(".json")
|
||||
metric_data = read_jsonl(metric_path)
|
||||
try:
|
||||
out[f"val_{metric_name}"] = mean(i["value"] for i in tail(5, metric_data))
|
||||
except StatisticsError:
|
||||
out[f"val_{metric_name}"] = float('nan')
|
||||
|
||||
for metric_path in (path / "scalars/IntersectionFieldAutoDecoderModel.training_step").glob("*.json"):
|
||||
if not any(i in metric_path.name for i in ("miss_radius_grad", "sphere_center_grad", "loss_tangential_reg", "multi_view")): continue
|
||||
if not metric_path.is_file() or not metric_path.stat().st_size: continue
|
||||
|
||||
metric_name = metric_path.name.removesuffix(".json")
|
||||
metric_data = read_jsonl(metric_path)
|
||||
try:
|
||||
out[f"train_{metric_name}"] = mean(i["value"] for i in tail(5, metric_data))
|
||||
except StatisticsError:
|
||||
out[f"train_{metric_name}"] = float('nan')
|
||||
|
||||
out.hostname = hparams._pickled_cli_args.host.hostname
|
||||
|
||||
for key, val in config.IntersectionFieldAutoDecoderModel.items():
|
||||
if isinstance(val, dict):
|
||||
out.update({f"hp_{key}_{k}": v for k, v in val.items()})
|
||||
elif isinstance(val, float | int | str | bool | None):
|
||||
out[f"hp_{key}"] = val
|
||||
|
||||
with cache.open("w") as f:
|
||||
json.dump(unmunchify(out), f)
|
||||
|
||||
return dict(out)
|
||||
|
||||
experiments = list(get_experiment_paths(filter, assert_dumped=not dump))
|
||||
if not experiments:
|
||||
raise typer.BadParameter("No matching experiment")
|
||||
if dump:
|
||||
try:
|
||||
tf_dump(experiments) # force=force_dump)
|
||||
except typer.BadParameter:
|
||||
pass
|
||||
|
||||
# does literally nothing, thanks GIL
|
||||
with ThreadPoolExecutor() as p:
|
||||
results = list(p.map(mk_summary, experiments))
|
||||
|
||||
if any(result is None for result in results):
|
||||
if all(result is None for result in results):
|
||||
print("No summary succeeded", file=sys.stderr)
|
||||
raise typer.Exit(exit_code=1)
|
||||
warnings.warn("Some summaries failed:\n" + "\n".join(
|
||||
str(experiment)
|
||||
for result, experiment in zip(results, experiments)
|
||||
if result is None
|
||||
))
|
||||
|
||||
summaries = "\n".join( map(json.dumps, results) )
|
||||
if not no_cache:
|
||||
cache.parent.mkdir(parents=True, exist_ok=True)
|
||||
with cache.open("w") as f:
|
||||
f.write(summaries)
|
||||
return summaries
|
||||
|
||||
@app.command(help="Show the summary of a experiment regex, precompute it if needed")
|
||||
def summary(filter: Optional[str] = typer.Argument(None), force: bool = False, dump: bool = False, all: bool = False):
|
||||
if filter is None:
|
||||
return list_cached_summaries("summary")
|
||||
|
||||
def key_mangler(key: str) -> str:
|
||||
for pattern, sub in (
|
||||
(r'^val_unscaled_loss_', r'val_uloss_'),
|
||||
(r'^train_unscaled_loss_', r'train_uloss_'),
|
||||
(r'^val_loss_', r'val_sloss_'),
|
||||
(r'^train_loss_', r'train_sloss_'),
|
||||
):
|
||||
key = re.sub(pattern, sub, key)
|
||||
|
||||
return key
|
||||
|
||||
cache = CACHED_SUMMARIES / f"{filter}.jsonl"
|
||||
if force or not (cache.is_file() and cache.stat().st_size):
|
||||
compute_summary(filter, force=force, dump=dump)
|
||||
assert cache.is_file() and cache.stat().st_size, (cache, cache.stat())
|
||||
|
||||
if os.isatty(0) and os.isatty(1) and shutil.which("vd"):
|
||||
rows = read_jsonl(cache)
|
||||
rows = ({key_mangler(k): v for k, v in row.items()} if row is not None else None for row in rows)
|
||||
if not all:
|
||||
rows = filter_jsonl_columns(rows)
|
||||
rows = ({k: v for k, v in row.items() if not k.startswith(("val_sloss_", "train_sloss_"))} for row in rows)
|
||||
data = "\n".join(map(json.dumps, rows))
|
||||
subprocess.run(["vd",
|
||||
#"--play", EXPERIMENTS / "set-key-columns.vd",
|
||||
"-f", "jsonl"
|
||||
], input=data, text=True, check=True)
|
||||
else:
|
||||
with cache.open() as f:
|
||||
print(f.read())
|
||||
|
||||
@app.command(help="Filter uninteresting keys from jsonl stdin")
|
||||
def filter_cols():
|
||||
rows = map(json.loads, (line for line in sys.stdin.readlines() if line.strip()))
|
||||
rows = filter_jsonl_columns(rows)
|
||||
print(*map(json.dumps, rows), sep="\n")
|
||||
|
||||
@app.command(help="Run a command for each experiment matched by experiment regex")
|
||||
def exec(filter: str, cmd: list[str], j: int = typer.Option(1, "-j"), dumped: bool = False, undumped: bool = False):
|
||||
# inspired by fd / gnu parallel
|
||||
def populate_cmd(experiment: Path, cmd: Iterable[str]) -> Iterable[str]:
|
||||
any = False
|
||||
for i in cmd:
|
||||
if i == "{}":
|
||||
any = True
|
||||
yield str(experiment / "hparams.yaml")
|
||||
elif i == "{//}":
|
||||
any = True
|
||||
yield str(experiment)
|
||||
else:
|
||||
yield i
|
||||
if not any:
|
||||
yield str(experiment / "hparams.yaml")
|
||||
|
||||
with ThreadPoolExecutor(max_workers=j or None) as p:
|
||||
results = p.map(subprocess.run, (
|
||||
list(populate_cmd(experiment, cmd))
|
||||
for experiment in get_experiment_paths(filter)
|
||||
if not dumped or (experiment / "scalars/epoch.json").is_file()
|
||||
if not undumped or not (experiment / "scalars/epoch.json").is_file()
|
||||
))
|
||||
|
||||
if any(i.returncode for i in results):
|
||||
return typer.Exit(1)
|
||||
|
||||
@app.command(help="Show stackplot of experiment loss")
|
||||
def stackplot(filter: str, write: bool = False, dump: bool = False, training: bool = False, unscaled: bool = False, force: bool = False):
|
||||
experiments = list(get_experiment_paths(filter, assert_dumped=not dump))
|
||||
if not experiments:
|
||||
raise typer.BadParameter("No match")
|
||||
if dump:
|
||||
try:
|
||||
tf_dump(experiments)
|
||||
except typer.BadParameter:
|
||||
pass
|
||||
|
||||
plot_losses(experiments,
|
||||
mode = "stackplot",
|
||||
write = write,
|
||||
dump = dump,
|
||||
training = training,
|
||||
unscaled = unscaled,
|
||||
force = force,
|
||||
)
|
||||
|
||||
@app.command(help="Show stackplot of experiment loss")
|
||||
def lineplot(filter: str, write: bool = False, dump: bool = False, training: bool = False, unscaled: bool = False, force: bool = False):
|
||||
experiments = list(get_experiment_paths(filter, assert_dumped=not dump))
|
||||
if not experiments:
|
||||
raise typer.BadParameter("No match")
|
||||
if dump:
|
||||
try:
|
||||
tf_dump(experiments)
|
||||
except typer.BadParameter:
|
||||
pass
|
||||
|
||||
plot_losses(experiments,
|
||||
mode = "lineplot",
|
||||
write = write,
|
||||
dump = dump,
|
||||
training = training,
|
||||
unscaled = unscaled,
|
||||
force = force,
|
||||
)
|
||||
|
||||
@app.command(help="Open tensorboard for the experiments matching the regex")
|
||||
def tensorboard(filter: Optional[str] = typer.Argument(None), watch: bool = False):
|
||||
if filter is None:
|
||||
return list_cached_summaries("tensorboard")
|
||||
experiments = list(get_experiment_paths(filter, assert_dumped=False))
|
||||
if not experiments:
|
||||
raise typer.BadParameter("No match")
|
||||
|
||||
with tempfile.TemporaryDirectory(suffix=f"ifield-{filter}") as d:
|
||||
treefarm = Path(d)
|
||||
with ThreadPoolExecutor(max_workers=2) as p:
|
||||
for experiment in experiments:
|
||||
(treefarm / experiment.name).symlink_to(experiment)
|
||||
|
||||
cmd = ["tensorboard", "--logdir", d]
|
||||
print("+", *map(shlex.quote, cmd), file=sys.stderr)
|
||||
tensorboard = p.submit(subprocess.run, cmd, check=True)
|
||||
if not watch:
|
||||
tensorboard.result()
|
||||
|
||||
else:
|
||||
all_experiments = set(get_experiment_paths(None, assert_dumped=False))
|
||||
while not tensorboard.done():
|
||||
time.sleep(10)
|
||||
new_experiments = set(get_experiment_paths(None, assert_dumped=False)) - all_experiments
|
||||
if new_experiments:
|
||||
for experiment in new_experiments:
|
||||
print(f"Adding {experiment.name!r}...", file=sys.stderr)
|
||||
(treefarm / experiment.name).symlink_to(experiment)
|
||||
all_experiments.update(new_experiments)
|
||||
|
||||
@app.command(help="Compute evaluation metrics")
|
||||
def metrics(filter: Optional[str] = typer.Argument(None), dump: bool = False, dry: bool = False, prefix: Optional[str] = typer.Option(None), derive: bool = False, each: bool = False, no_total: bool = False):
|
||||
if filter is None:
|
||||
return list_cached_summaries("metrics --derive")
|
||||
experiments = list(get_experiment_paths(filter, assert_dumped=False))
|
||||
if not experiments:
|
||||
raise typer.BadParameter("No match")
|
||||
if dump:
|
||||
try:
|
||||
tf_dump(experiments)
|
||||
except typer.BadParameter:
|
||||
pass
|
||||
|
||||
def run(*cmd):
|
||||
if prefix is not None:
|
||||
cmd = [*shlex.split(prefix), *cmd]
|
||||
if dry:
|
||||
print(*map(shlex.quote, map(str, cmd)))
|
||||
else:
|
||||
print("+", *map(shlex.quote, map(str, cmd)))
|
||||
subprocess.run(cmd)
|
||||
|
||||
for experiment in experiments:
|
||||
if no_total: continue
|
||||
if not (experiment / "compute-scores/metrics.json").is_file():
|
||||
run(
|
||||
"python", "./marf.py", "module", "--best", experiment / "hparams.yaml",
|
||||
"compute-scores", experiment / "compute-scores/metrics.json",
|
||||
"--transpose",
|
||||
)
|
||||
if not (experiment / "compute-scores/metrics-last.json").is_file():
|
||||
run(
|
||||
"python", "./marf.py", "module", "--last", experiment / "hparams.yaml",
|
||||
"compute-scores", experiment / "compute-scores/metrics-last.json",
|
||||
"--transpose",
|
||||
)
|
||||
if "2prif-" not in experiment.name: continue
|
||||
if not (experiment / "compute-scores/metrics-sans_outliers.json").is_file():
|
||||
run(
|
||||
"python", "./marf.py", "module", "--best", experiment / "hparams.yaml",
|
||||
"compute-scores", experiment / "compute-scores/metrics-sans_outliers.json",
|
||||
"--transpose", "--filter-outliers"
|
||||
)
|
||||
if not (experiment / "compute-scores/metrics-last-sans_outliers.json").is_file():
|
||||
run(
|
||||
"python", "./marf.py", "module", "--last", experiment / "hparams.yaml",
|
||||
"compute-scores", experiment / "compute-scores/metrics-last-sans_outliers.json",
|
||||
"--transpose", "--filter-outliers"
|
||||
)
|
||||
|
||||
if dry: return
|
||||
if prefix is not None:
|
||||
print("prefix was used, assuming a job scheduler was used, will not print scores.", file=sys.stderr)
|
||||
return
|
||||
|
||||
metrics = [
|
||||
*(experiment / "compute-scores/metrics.json" for experiment in experiments),
|
||||
*(experiment / "compute-scores/metrics-last.json" for experiment in experiments),
|
||||
*(experiment / "compute-scores/metrics-sans_outliers.json" for experiment in experiments if "2prif-" in experiment.name),
|
||||
*(experiment / "compute-scores/metrics-last-sans_outliers.json" for experiment in experiments if "2prif-" in experiment.name),
|
||||
]
|
||||
if not no_total:
|
||||
assert all(metric.exists() for metric in metrics)
|
||||
else:
|
||||
metrics = (metric for metric in metrics if metric.exists())
|
||||
|
||||
out = []
|
||||
for metric in metrics:
|
||||
experiment = metric.parent.parent.name
|
||||
is_last = metric.name in ("metrics-last.json", "metrics-last-sans_outliers.json")
|
||||
with metric.open() as f:
|
||||
data = json.load(f)
|
||||
|
||||
if derive:
|
||||
derived = {}
|
||||
objs = [i for i in data.keys() if i != "_hparams"]
|
||||
for obj in (objs if each else []) + [None]:
|
||||
if obj is None:
|
||||
d = DefaultMunch(0)
|
||||
for obj in objs:
|
||||
for k, v in data[obj].items():
|
||||
d[k] += v
|
||||
obj = "_all_"
|
||||
n_cd = data["_hparams"]["n_cd"] * len(objs)
|
||||
n_emd = data["_hparams"]["n_emd"] * len(objs)
|
||||
else:
|
||||
d = munchify(data[obj])
|
||||
n_cd = data["_hparams"]["n_cd"]
|
||||
n_emd = data["_hparams"]["n_emd"]
|
||||
|
||||
precision = d.TP / (d.TP + d.FP)
|
||||
recall = d.TP / (d.TP + d.FN)
|
||||
derived[obj] = dict(
|
||||
filtered = d.n_outliers / d.n if "n_outliers" in d else None,
|
||||
iou = d.TP / (d.TP + d.FN + d.FP),
|
||||
precision = precision,
|
||||
recall = recall,
|
||||
f_score = 2 * (precision * recall) / (precision + recall),
|
||||
cd = d.cd_dist / n_cd,
|
||||
emd = d.emd / n_emd,
|
||||
cos_med = 1 - (d.cd_cos_med / n_cd) if "cd_cos_med" in d else None,
|
||||
cos_jac = 1 - (d.cd_cos_jac / n_cd),
|
||||
)
|
||||
data = derived if each else derived["_all_"]
|
||||
|
||||
data["uid"] = experiment.rsplit("-", 1)[-1]
|
||||
data["experiment_name"] = experiment
|
||||
data["is_last"] = is_last
|
||||
|
||||
out.append(json.dumps(data))
|
||||
|
||||
if derive and not each and os.isatty(0) and os.isatty(1) and shutil.which("vd"):
|
||||
subprocess.run(["vd", "-f", "jsonl"], input="\n".join(out), text=True, check=True)
|
||||
else:
|
||||
print("\n".join(out))
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
Reference in New Issue
Block a user