2025-01-09 15:43:11 +01:00

850 lines
33 KiB
Python
Executable File

#!/usr/bin/env python
from concurrent.futures import ThreadPoolExecutor, Future, ProcessPoolExecutor
from functools import partial
from more_itertools import first, last, tail
from munch import Munch, DefaultMunch, munchify, unmunchify
from pathlib import Path
from statistics import mean, StatisticsError
from mpl_toolkits.axes_grid1 import make_axes_locatable
from typing import Iterable, Optional, Literal
from math import isnan
import json
import stat
import matplotlib
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import os, os.path
import re
import shlex
import time
import itertools
import shutil
import subprocess
import sys
import traceback
import typer
import warnings
import yaml
import tempfile
# Directory that contains this script. All log artifacts are looked up beneath it.
# (Resolving to the script *file* would make every path below hang under a regular
# file, so none of them could ever be a directory.)
EXPERIMENTS = Path(__file__).resolve().parent
LOGDIR = EXPERIMENTS / "logdir"
TENSORBOARD = LOGDIR / "tensorboard"
SLURM_LOGS = LOGDIR / "slurm_logs"
CACHED_SUMMARIES = LOGDIR / "cached_summaries"
COMPUTED_SCORES = LOGDIR / "computed_scores"

# Unique sentinel; `filter_jsonl_columns` adds it to a column's set of seen values
# when at least one row lacks that column.
MISSING = object()
class SafeLoaderIgnoreUnknown(yaml.SafeLoader):
    """A `yaml.SafeLoader` that loads nodes with unrecognised tags as ``None``."""

    def ignore_unknown(self, node):
        """Constructor for unknown tags: discard `node` and produce ``None``."""
        return None


# Registering a constructor under the tag `None` makes it the fallback used for
# every tag that has no constructor of its own.
SafeLoaderIgnoreUnknown.add_constructor(None, SafeLoaderIgnoreUnknown.ignore_unknown)
def camel_to_snake_case(text: str, sep: str = "_", join_abbreviations: bool = False) -> str:
    """Convert a CamelCase identifier to lower-case words joined by `sep`.

    The text is split in front of every upper-case ASCII letter. With
    `join_abbreviations`, adjacent single-letter words (e.g. the letters of
    "HTTP") are glued back together; that step cannot be reversed.
    """
    words = [chunk.lower() for chunk in re.split(r'(?=[A-Z])', text) if chunk]
    if join_abbreviations:
        # Decide which neighbours to merge from the *original* word lengths, and
        # merge from the right so earlier indices stay valid while popping.
        lengths = [len(word) for word in words]
        for idx in range(len(words) - 2, -1, -1):
            if lengths[idx] == 1 and lengths[idx + 1] == 1:
                words[idx] += words.pop(idx + 1)
    return sep.join(words)
def flatten_dict(data: dict, key_mapper: callable = lambda x: x) -> dict:
    """Flatten one level of nested dicts into ``"<parent>/<child>"`` keys.

    `key_mapper` is applied to the parent key only. When `data` holds no dict
    values it is returned as-is (the very same object). Deeper nesting levels
    are left untouched.
    """
    if not any(isinstance(value, dict) for value in data.values()):
        return data

    # Plain entries first, so they keep their position in the result ...
    flat = {key: value for key, value in data.items() if not isinstance(value, dict)}
    # ... followed by the entries of every nested dict under a prefixed key.
    for parent, child in data.items():
        if not isinstance(child, dict):
            continue
        for key, value in child.items():
            flat[f"{key_mapper(parent)}/{key}"] = value
    return flat
def parse_jsonl(data: str) -> Iterable[dict]:
    """Lazily decode JSON-lines text, skipping blank lines."""
    for raw_line in data.splitlines():
        if raw_line.strip():
            yield json.loads(raw_line)
def read_jsonl(path: Path) -> Iterable[dict]:
    """Lazily yield the records of the JSON-lines file at `path`.

    The file is read in full (and closed again) on the first iteration step.
    """
    text = path.read_text()
    yield from parse_jsonl(text)
def get_experiment_paths(filter: str | None, assert_dumped = False) -> Iterable[Path]:
    """Yield the experiment directories under `TENSORBOARD` that look complete.

    `filter` is a regular expression searched for in the directory name
    (``None`` matches everything). Directories without ``hparams.yaml`` or
    without any ``events.out.tfevents.*`` file are skipped with a warning.
    With `assert_dumped` (and assertions enabled) the dumped scalar files
    must exist as well.
    """
    dumped_files = (
        "scalars/epoch.json",
        "scalars/IntersectionFieldAutoDecoderModel.validation_step/loss.json",
        "scalars/IntersectionFieldAutoDecoderModel.training_step/loss.json",
    )
    for candidate in TENSORBOARD.iterdir():
        name_matches = filter is None or re.search(filter, candidate.name)
        if not name_matches or not candidate.is_dir():
            continue
        if not (candidate / "hparams.yaml").is_file():
            warnings.warn(f"Missing hparams: {candidate}")
            continue
        if next(candidate.glob("events.out.tfevents.*"), None) is None:
            warnings.warn(f"Missing tfevents: {candidate}")
            continue
        if __debug__ and assert_dumped:
            for relative in dumped_files:
                assert (candidate / relative).is_file(), candidate
        yield candidate
def dump_pl_tensorboard_hparams(experiment: Path):
    """Extract files from the ``hparams.yaml`` of the `experiment` directory.

    Reads ``_pickled_cli_args`` from ``hparams.yaml`` and writes, next to it:

    * ``config.yaml``  - the raw YAML config, preceded by a comment holding the
      shell-quoted ``sys_argv``; a leading shebang is kept on the first line
      and the file is then made user-executable,
    * ``environ.yaml`` - ``host.environ`` dumped as YAML,
    * ``repo.patch``   - ``host.vcs`` (the literal text ``None`` if absent).

    Progress is reported on stderr.
    """
    with (experiment / "hparams.yaml").open() as f:
        hparams = yaml.load(f, Loader=SafeLoaderIgnoreUnknown)
    cli_args = hparams.get('_pickled_cli_args', {})
    host = cli_args.get('host', {})

    config_path = experiment / "config.yaml"
    shebang = None
    with config_path.open("w") as f:
        # Normalise Windows line endings (CR LF). A surviving "\r" would end up
        # at the end of the shebang line and break execution of the file.
        raw_yaml = cli_args.get('_raw_yaml', "").replace("\r\n", "\n")
        if raw_yaml.startswith("#!"):  # preserve shebang
            shebang, _, raw_yaml = raw_yaml.partition("\n")
            f.write(f"{shebang}\n")
        f.write(f"# {' '.join(map(shlex.quote, cli_args.get('sys_argv', ['None'])))}\n\n")
        f.write(raw_yaml)
    if shebang is not None:
        os.chmod(config_path, config_path.stat().st_mode | stat.S_IXUSR)
    print(config_path, "written!", file=sys.stderr)

    environ_path = experiment / "environ.yaml"
    with environ_path.open("w") as f:
        yaml.safe_dump(host.get('environ'), f)
    print(environ_path, "written!", file=sys.stderr)

    patch_path = experiment / "repo.patch"
    with patch_path.open("w") as f:
        f.write(host.get('vcs', "None"))
    print(patch_path, "written!", file=sys.stderr)
def dump_simple_tf_events_to_jsonl(output_dir: Path, *tf_files: Path):
    """Dump the ``simpleValue`` summaries of tensorboard event files as JSON lines.

    Every tag gets its own file ``<output_dir>/<tag>.json`` (truncated when it
    is first opened during this call). Each line is an object with the keys
    ``step``, ``value`` and ``wall_time``.
    """
    from google.protobuf.json_format import MessageToDict
    import tensorboard.backend.event_processing.event_accumulator

    # Shared empty fallbacks, so events without a summary allocate nothing.
    no_summary, no_values = {}, []
    writers = {}  # tag -> open file object
    try:
        for tf_file in tf_files:
            event_loader = tensorboard.backend.event_processing.event_file_loader.LegacyEventFileLoader(str(tf_file))
            for event in event_loader.Load():
                entries = MessageToDict(event).get("summary", no_summary).get("value", no_values)
                for entry in entries:
                    if "simpleValue" not in entry:
                        continue
                    tag = entry["tag"]
                    if tag not in writers:
                        target = output_dir / f"{tag}.json"
                        print(f"Opening {str(target)!r}...", file=sys.stderr)
                        target.parent.mkdir(parents=True, exist_ok=True)
                        writers[tag] = target.open("w")
                    raw_value = entry["simpleValue"]
                    # String values are converted with float(); anything else is kept as-is.
                    value = float(raw_value) if isinstance(raw_value, str) else raw_value
                    record = json.dumps({
                        "step" : event.step,
                        "value" : value,
                        "wall_time" : event.wall_time,
                    })
                    writers[tag].write(f"{record}\n")
    finally:
        if writers:
            print("Closing json files...", file=sys.stderr)
        for writer in writers.values():
            writer.close()
# Column names that `filter_jsonl_columns` keeps by default, even when their
# value is identical in every row.
# (The literal used to list "_val_uloss_intersection" twice; a set ignores the
# duplicate, so dropping it does not change the contents.)
NO_FILTER = {
    "__uid",
    "_minutes",
    "_epochs",
    "_hp_nonlinearity",
    "_val_uloss_intersection",
    "_val_uloss_normal_cossim",
}
def filter_jsonl_columns(data: Iterable[dict | None], no_filter=NO_FILTER) -> list[dict]:
    """Reduce summary rows to the columns that carry information.

    Processing steps, in order:

    1. drop ``None`` rows,
    2. flatten one level of nested dicts (parent keys converted to snake_case),
    3. merge ``hp_omega_0`` into the nonlinearity column when that is ``"sine"``,
    4. drop columns that are falsy / constant across all rows, except the
       columns named in `no_filter`,
    5. coerce every column to a single common type.

    Returns the processed rows as a list, as the annotation promises
    (a lazy generator used to be returned instead).
    """

    def merge_siren_omega(data: dict) -> dict:
        """Drop ``hp_omega_0`` and append it to a ``"sine"`` nonlinearity value."""
        return {
            key: (
                f"{val}-{data.get('hp_omega_0', 'ERROR')}"
                if (key.removeprefix("_"), val) == ("hp_nonlinearity", "sine") else
                val
            )
            for key, val in data.items()
            if key != "hp_omega_0"
        }

    def remove_uninteresting_cols(rows: list[dict]) -> Iterable[dict]:
        """Keep columns that are non-trivial somewhere and not constant everywhere."""
        unique_vals = {}     # column -> set of repr() of every value seen
        whitelisted = set()  # columns with at least one "interesting" value
        for row in rows:
            for key, val in row.items():
                unique_vals.setdefault(key, set()).add(repr(val))
                if val and val not in ("None", "0", "0.0"):
                    whitelisted.add(key)

        # A column absent from some row counts as having one more distinct value.
        for key, vals in unique_vals.items():
            if any(key not in row for row in rows):
                vals.add(MISSING)

        # Columns with a single distinct value tell nothing apart.
        for key, vals in unique_vals.items():
            if len(vals) == 1:
                whitelisted.discard(key)

        whitelisted.update(no_filter)
        yield from (
            {
                key: val
                for key, val in row.items()
                if key in whitelisted
            }
            for row in rows
        )

    def pessimize_types(rows: list[dict]) -> Iterable[dict]:
        """Cast each column to the earliest type in `order` seen in that column."""
        types = {}
        order = (str, float, int, bool, tuple, type(None))
        for row in rows:
            for key, val in row.items():
                if isinstance(val, list): val = tuple(val)
                # exact type match on purpose: bool is a subclass of int
                assert type(val) in order, (type(val), val)
                index = order.index(type(val))
                types[key] = min(types.get(key, 999), index)
        yield from (
            {
                key: order[types[key]](val) if val is not None else None
                for key, val in row.items()
            }
            for row in rows
        )

    rows = [row for row in data if row is not None]
    rows = [flatten_dict(row, key_mapper=camel_to_snake_case) for row in rows]
    rows = [merge_siren_omega(row) for row in rows]
    rows = list(remove_uninteresting_cols(rows))
    return list(pessimize_types(rows))
PlotMode = Literal["stackplot", "lineplot"]


def plot_losses(experiments: list[Path], mode: PlotMode, write: bool = False, dump: bool = False, training: bool = False, unscaled: bool = False, force=True):
    """Plot the dumped loss terms of each experiment as a stack- or lineplot.

    Without `write`, one figure per experiment is built for the configuration
    selected by `training` / `unscaled` and shown interactively. With `write`,
    all four training/validation x scaled/unscaled combinations are rendered
    in a thread pool and saved as PNG under ``<experiment>/plots/``; existing
    files are only overwritten when `force` is set.

    `dump` is accepted but not used by this function.
    """
    def get_losses(experiment: Path, training: bool = True, unscaled: bool = False) -> Iterable[Path]:
        """Glob the dumped per-term loss files for one configuration."""
        step = "training_step" if training else "validation_step"
        prefix = "unscaled_loss_" if unscaled else "loss_"
        return experiment.glob(f"scalars/*.{step}/{prefix}*.json")

    def legend_key(loss: Path) -> str:
        """``<model>.<loss name>`` with the ``unscaled_`` prefix removed."""
        model = loss.parent.name.split(".", 1)[0]
        name = loss.name.removesuffix(loss.suffix).removeprefix("unscaled_")
        return f"{model}.{name}"

    print("Mapping colors...")
    configurations = [
        dict(unscaled=unscaled, training=training),
    ] if not write else [
        dict(unscaled=False, training=False),
        dict(unscaled=False, training=True),
        dict(unscaled=True, training=False),
        dict(unscaled=True, training=True),
    ]
    legends = set(
        legend_key(loss)
        for experiment in experiments
        for kw in configurations
        for loss in get_losses(experiment, **kw)
    )
    # One fixed color per loss term, shared by all figures.
    colormap = dict(zip(
        sorted(legends),
        itertools.cycle(mcolors.TABLEAU_COLORS),
    ))

    def mkplot(experiment: Path, training: bool = True, unscaled: bool = False) -> tuple[bool, str | None]:
        """Build one figure; returns ``(skipped, reason)``."""
        label = f"{'unscaled' if unscaled else 'scaled'} {'training' if training else 'validation'}"
        if write:
            old_savefig_fname = experiment / f"{label.replace(' ', '-')}-{mode}.png"
            savefig_fname = experiment / "plots" / f"{label.replace(' ', '-')}-{mode}.png"
            savefig_fname.parent.mkdir(exist_ok=True, parents=True)
            # migrate plots saved at the old location
            if old_savefig_fname.is_file():
                old_savefig_fname.rename(savefig_fname)
            if savefig_fname.is_file() and not force:
                return True, "savefig_fname already exists"

        # Get and sort data
        losses = {}
        for loss in get_losses(experiment, training=training, unscaled=unscaled):
            losses[legend_key(loss)] = (loss, list(read_jsonl(loss)))
        losses = dict(sorted(losses.items())) # sort keys
        if not losses:
            return True, "no losses"

        # unwrap; the steps of the first loss term are used for all of them
        steps = [i["step"] for i in first(losses.values())[1]]
        values = [
            [i["value"] if not isnan(i["value"]) else 0 for i in data]
            for name, (scalar, data) in losses.items()
        ]

        # normalize
        if mode == "stackplot":
            totals = list(map(sum, zip(*values)))
            values = [
                # a step whose terms sum to 0 would otherwise divide by zero
                [i / t if t else 0 for i, t in zip(data, totals)]
                for data in values
            ]

        print(experiment.name, label)
        fig, ax = plt.subplots(figsize=(16, 12))

        if mode == "stackplot":
            ax.stackplot(steps, values,
                colors = list(map(colormap.__getitem__, losses.keys())),
                labels = list(
                    key.split(".", 1)[1].removeprefix("loss_")
                    for key in losses.keys()
                ),
            )
            ax.set_xlim(0, steps[-1])
            ax.set_ylim(0, 1)
            ax.invert_yaxis()
        elif mode == "lineplot":
            # The loop variable must not be called `label`: that would overwrite
            # the configuration label used in the title below.
            for data, color, series_label in zip(
                values,
                map(colormap.__getitem__, losses.keys()),
                list(losses.keys()),
            ):
                ax.plot(steps, data,
                    color = color,
                    label = series_label,
                )
            ax.set_xlim(0, steps[-1])
        else:
            raise ValueError(f"{mode=}")

        ax.legend()
        ax.set_title(f"{label} loss\n{experiment.name}")
        ax.set_xlabel("Step")
        ax.set_ylabel("loss%")

        if mode == "stackplot":
            # extra strip below the main axes showing the un-normalized total
            ax2 = make_axes_locatable(ax).append_axes("bottom", 0.8, pad=0.05, sharex=ax)
            ax2.stackplot( steps, totals )
            for tl in ax.get_xticklabels(): tl.set_visible(False)

        fig.tight_layout()

        if write:
            fig.savefig(savefig_fname, dpi=300)
            print(savefig_fname)
            plt.close(fig)

        return False, None

    print("Plotting...")
    if write:
        matplotlib.use('agg') # fixes "WARNING: QApplication was not created in the main() thread."
    any_error = False
    if write:
        with ThreadPoolExecutor(max_workers=None) as pool:
            futures = [
                (experiment, pool.submit(mkplot, experiment, **kw))
                for experiment in experiments
                for kw in configurations
            ]
    else:
        # Interactive figures are built right here in the calling thread; the
        # results are wrapped in futures so both modes share the loop below.
        def mkfuture(item):
            f = Future()
            f.set_result(item)
            return f
        futures = [
            (experiment, mkfuture(mkplot(experiment, **kw)))
            for experiment in experiments
            for kw in configurations
        ]

    for experiment, future in futures:
        try:
            err, msg = future.result()
        except Exception:
            traceback.print_exc(file=sys.stderr)
            any_error = True
            continue
        if err:
            print(f"{msg}: {experiment.name}")
            any_error = True
            continue

    if not any_error and not write: # show in main thread
        plt.show()
    elif not write:
        print("There were errors, will not show figure...", file=sys.stderr)
# =========
# Command line interface: every `@app.command` function below becomes a sub-command.
# Invoking the program without arguments prints the help text; shell completion
# options are not added.
app = typer.Typer(no_args_is_help=True, add_completion=False)
@app.command(help="Dump simple tensorboard events to json and extract some pytorch lightning hparams")
def tf_dump(tfevent_files: list[Path], j: int = typer.Option(1, "-j"), force: bool = False):
    """Dump scalars and hparams for the given event files / experiment dirs.

    Each entry of `tfevent_files` may be an ``events.out.tfevents.*`` file, an
    experiment directory, or an ``hparams.yaml`` / ``config.yaml`` file; the
    latter two are expanded to all event files located next to them. Unless
    `force` is set, event files older than an existing ``scalars/epoch.json``
    are skipped.

    Raises `typer.BadParameter` when nothing is left to do. Failures inside the
    worker processes are reported on stderr and do not abort the other jobs.

    NOTE(review): `j` is accepted but not forwarded to the process pool, which
    always uses its default worker count.
    """
    # expand to all tfevents files (there may be more than one)
    tfevent_files = sorted(set([
        tffile
        for tffile in tfevent_files
        if tffile.name.startswith("events.out.tfevents.")
    ] + [
        tffile
        for experiment_dir in tfevent_files
        if experiment_dir.is_dir()
        for tffile in experiment_dir.glob("events.out.tfevents.*")
    ] + [
        tffile
        for hparam_file in tfevent_files
        if hparam_file.name in ("hparams.yaml", "config.yaml")
        for tffile in hparam_file.parent.glob("events.out.tfevents.*")
    ]))

    # filter already dumped
    if not force:
        tfevent_files = [
            tffile
            for tffile in tfevent_files
            if not (
                (tffile.parent / "scalars/epoch.json").is_file()
                and
                tffile.stat().st_mtime < (tffile.parent / "scalars/epoch.json").stat().st_mtime
            )
        ]

    if not tfevent_files:
        raise typer.BadParameter("Nothing to be done, consider --force")

    # group the event files by the directory their scalars are dumped into
    jobs = {}
    for tffile in tfevent_files:
        if not tffile.is_file():
            print("ERROR: file not found:", tffile, file=sys.stderr)
            continue
        output_dir = tffile.parent / "scalars"
        jobs.setdefault(output_dir, []).append(tffile)

    with ProcessPoolExecutor() as p:
        futures = [
            p.submit(dump_pl_tensorboard_hparams, experiment)
            for experiment in set(tffile.parent for tffile in tfevent_files)
        ] + [
            p.submit(dump_simple_tf_events_to_jsonl, output_dir, *tffiles)
            for output_dir, tffiles in jobs.items()
        ]

    # The pool has been shut down here, so every future is finished. Surface
    # worker exceptions instead of dropping them together with the futures.
    for future in futures:
        exc = future.exception()
        if exc is not None:
            traceback.print_exception(type(exc), exc, exc.__traceback__, file=sys.stderr)
@app.command(help="Propose experiment regexes")
def propose(cmd: str = typer.Argument("summary"), null: bool = False):
    """Print one ready-to-run command line per (experiment name, date) combination.

    The regexes are derived from the names of the directories in `TENSORBOARD`
    that contain an ``hparams.yaml``, and are sorted by date. With `null` each
    command is prefixed with a redirect of stdout to /dev/null.
    """
    def candidates():
        for entry in TENSORBOARD.iterdir():
            if entry.is_dir() and (entry / "hparams.yaml").is_file():
                # directory names are split on "-" into:
                # prefix, name, <any number of hparams>, year, month, day, hhmm, uid
                prefix, name, *hparams, year, month, day, hhmm, uid = entry.name.split("-")
                yield f"{name}.*-{year}-{month}-{day}"

    def by_date(regex: str) -> str:
        return regex.split(".*-", 1)[1]

    redirect = '>/dev/null ' if null else ''
    command_lines = [
        f"{redirect}{sys.argv[0]} {cmd or 'summary'} {shlex.quote(regex)}"
        for regex in sorted(set(candidates()), key=by_date)
    ]
    print("\n".join(command_lines))
@app.command("list", help="List used experiment regexes")
def list_cached_summaries(cmd: str = typer.Argument("summary")):
    """Print a command line for every non-empty ``*.jsonl`` file in `CACHED_SUMMARIES`.

    The file name without its ``.jsonl`` suffix is used as the experiment regex.
    """
    cached = []
    if CACHED_SUMMARIES.is_dir():
        for entry in CACHED_SUMMARIES.iterdir():
            if entry.suffix == ".jsonl" and entry.is_file() and entry.stat().st_size:
                cached.append(entry.name.removesuffix(".jsonl"))

    def order(key: str) -> list[str]:
        # Sort by the digit groups found after the last ".*" of the regex,
        # using the full regex as tie-breaker.
        trailing = key.split(".*")[-1]
        digit_groups = re.sub(r'[^0-9\-]', '', trailing).strip("-").split("-")
        return digit_groups + [key]

    command_lines = [
        f"{sys.argv[0]} {cmd or 'summary'} {shlex.quote(regex)}"
        for regex in sorted(cached, key=order)
    ]
    print("\n".join(command_lines))
@app.command(help="Precompute the summary of a experiment regex")
def compute_summary(filter: str, force: bool = False, dump: bool = False, no_cache: bool = False):
    """Build one summary row per experiment matching `filter`.

    Returns the rows as JSON-lines text (``null`` for experiments whose summary
    failed) and, unless `no_cache` is set, also writes that text to
    ``CACHED_SUMMARIES/<filter>.jsonl``.

    Raises `FileExistsError` when a non-empty cache exists and `force` is not
    set, `typer.BadParameter` when no experiment matches, and `typer.Exit(1)`
    when every summary failed.
    """
    cache = CACHED_SUMMARIES / f"{filter}.jsonl"
    if cache.is_file() and cache.stat().st_size:
        if not force:
            raise FileExistsError(cache)

    def mk_summary(path: Path) -> dict | None:
        """Summarise one experiment dir; ``None`` if its loss dumps are unreadable."""
        cache = path / "train_summary.json"
        # Reuse the per-experiment cache while it is newer than the dumped scalars.
        if cache.is_file() and cache.stat().st_size and cache.stat().st_mtime > (path/"scalars/epoch.json").stat().st_mtime:
            with cache.open() as f:
                return json.load(f)

        with (path / "hparams.yaml").open() as f:
            hparams = munchify(yaml.load(f, Loader=SafeLoaderIgnoreUnknown), factory=partial(DefaultMunch, None))
        config = hparams._pickled_cli_args._raw_yaml
        config = munchify(yaml.load(config, Loader=SafeLoaderIgnoreUnknown), factory=partial(DefaultMunch, None))

        try:
            train_loss = list(read_jsonl(path / "scalars/IntersectionFieldAutoDecoderModel.training_step/loss.json"))
            val_loss = list(read_jsonl(path / "scalars/IntersectionFieldAutoDecoderModel.validation_step/loss.json"))
        except Exception:  # was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit
            traceback.print_exc(file=sys.stderr)
            return None

        def tail_mean(metric_path: Path) -> float:
            """Mean of the last five logged values, NaN if there are none."""
            try:
                return mean(i["value"] for i in tail(5, read_jsonl(metric_path)))
            except StatisticsError:
                return float('nan')

        out = Munch()
        out.uid = path.name.rsplit("-", 1)[-1]
        out.name = path.name
        out.date = "-".join(path.name.split("-")[-5:-1])
        out.epochs = int(last(read_jsonl(path / "scalars/epoch.json"))["value"])
        out.steps = val_loss[-1]["step"]
        out.gpu = hparams._pickled_cli_args.host.gpus[1][1]

        val_duration = val_loss[-1]["wall_time"] - val_loss[0]["wall_time"]
        if val_duration > 0:
            out.batches_per_second = val_loss[-1]["step"] / val_duration
        else:
            out.batches_per_second = 0
        out.minutes = (val_loss[-1]["wall_time"] - train_loss[0]["wall_time"]) / 60

        gpu_memory_log = path / "scalars/PsutilMonitor/gpu.00.memory.used.json"
        if gpu_memory_log.is_file():
            # The peak used to be computed and then thrown away; keep it.
            # NOTE(review): the key name was chosen by this change - confirm it.
            out.gpu_memory_used_max = max(i["value"] for i in read_jsonl(gpu_memory_log))

        # every validation metric ...
        for metric_path in (path / "scalars/IntersectionFieldAutoDecoderModel.validation_step").glob("*.json"):
            if not metric_path.is_file() or not metric_path.stat().st_size: continue
            metric_name = metric_path.name.removesuffix(".json")
            out[f"val_{metric_name}"] = tail_mean(metric_path)

        # ... but only a hand-picked selection of the training metrics
        for metric_path in (path / "scalars/IntersectionFieldAutoDecoderModel.training_step").glob("*.json"):
            if not any(i in metric_path.name for i in ("miss_radius_grad", "sphere_center_grad", "loss_tangential_reg", "multi_view")): continue
            if not metric_path.is_file() or not metric_path.stat().st_size: continue
            metric_name = metric_path.name.removesuffix(".json")
            out[f"train_{metric_name}"] = tail_mean(metric_path)

        out.hostname = hparams._pickled_cli_args.host.hostname

        # hyperparameters: scalars as "hp_<key>", one level of nesting as "hp_<key>_<subkey>"
        for key, val in config.IntersectionFieldAutoDecoderModel.items():
            if isinstance(val, dict):
                out.update({f"hp_{key}_{k}": v for k, v in val.items()})
            elif isinstance(val, float | int | str | bool | None):
                out[f"hp_{key}"] = val

        with cache.open("w") as f:
            json.dump(unmunchify(out), f)

        return dict(out)

    experiments = list(get_experiment_paths(filter, assert_dumped=not dump))
    if not experiments:
        raise typer.BadParameter("No matching experiment")
    if dump:
        try:
            tf_dump(experiments) # force=force_dump)
        except typer.BadParameter:
            # raised by tf_dump when there is nothing (left) to dump
            pass

    # does literally nothing, thanks GIL
    with ThreadPoolExecutor() as p:
        results = list(p.map(mk_summary, experiments))

    if any(result is None for result in results):
        if all(result is None for result in results):
            print("No summary succeeded", file=sys.stderr)
            raise typer.Exit(exit_code=1)
        warnings.warn("Some summaries failed:\n" + "\n".join(
            str(experiment)
            for result, experiment in zip(results, experiments)
            if result is None
        ))

    summaries = "\n".join( map(json.dumps, results) )
    if not no_cache:
        cache.parent.mkdir(parents=True, exist_ok=True)
        with cache.open("w") as f:
            f.write(summaries)
    return summaries
@app.command(help="Show the summary of a experiment regex, precompute it if needed")
def summary(filter: Optional[str] = typer.Argument(None), force: bool = False, dump: bool = False, all: bool = False):
    """Show the cached summary for `filter`, computing it first when required.

    Without `filter` the cached summaries are listed instead. When stdin and
    stdout are terminals and ``vd`` is on the PATH, the rows are opened in it
    (with uninteresting columns removed unless `all` is set); otherwise the
    cache file is printed verbatim.
    """
    if filter is None:
        return list_cached_summaries("summary")

    # prefix renames applied to the column names before display
    renames = (
        (r'^val_unscaled_loss_', r'val_uloss_'),
        (r'^train_unscaled_loss_', r'train_uloss_'),
        (r'^val_loss_', r'val_sloss_'),
        (r'^train_loss_', r'train_sloss_'),
    )

    def key_mangler(key: str) -> str:
        for pattern, sub in renames:
            key = re.sub(pattern, sub, key)
        return key

    cache = CACHED_SUMMARIES / f"{filter}.jsonl"
    if force or not (cache.is_file() and cache.stat().st_size):
        compute_summary(filter, force=force, dump=dump)
    assert cache.is_file() and cache.stat().st_size, (cache, cache.stat())

    interactive = os.isatty(0) and os.isatty(1) and shutil.which("vd")
    if not interactive:
        with cache.open() as f:
            print(f.read())
        return

    def mangled_rows():
        for row in read_jsonl(cache):
            if row is None:
                yield None
            else:
                yield {key_mangler(k): v for k, v in row.items()}

    rows = mangled_rows()
    if not all:
        rows = filter_jsonl_columns(rows)
        hidden_prefixes = ("val_sloss_", "train_sloss_")
        rows = ({k: v for k, v in row.items() if not k.startswith(hidden_prefixes)} for row in rows)
    data = "\n".join(map(json.dumps, rows))
    subprocess.run(["vd",
        #"--play", EXPERIMENTS / "set-key-columns.vd",
        "-f", "jsonl"
    ], input=data, text=True, check=True)
@app.command(help="Filter uninteresting keys from jsonl stdin")
def filter_cols():
    """Read JSON lines from stdin, filter their columns, print them as JSON lines."""
    non_blank = (line for line in sys.stdin.readlines() if line.strip())
    parsed = map(json.loads, non_blank)
    filtered = filter_jsonl_columns(parsed)
    print(*(json.dumps(row) for row in filtered), sep="\n")
@app.command(help="Run a command for each experiment matched by experiment regex")
def exec(filter: str, cmd: list[str], j: int = typer.Option(1, "-j"), dumped: bool = False, undumped: bool = False):
    """Run `cmd` once per matching experiment, with up to `j` running in parallel.

    Placeholders in `cmd` (inspired by fd / gnu parallel):

    * ``{}``   - path of the experiment's ``hparams.yaml``
    * ``{//}`` - path of the experiment directory

    When `cmd` contains no placeholder, the ``hparams.yaml`` path is appended.
    `dumped` / `undumped` restrict the run to experiments with / without a
    ``scalars/epoch.json``.

    Exits with status 1 if any of the commands returned a non-zero exit code.
    """
    def populate_cmd(experiment: Path, cmd: Iterable[str]) -> Iterable[str]:
        substituted = False
        for i in cmd:
            if i == "{}":
                substituted = True
                yield str(experiment / "hparams.yaml")
            elif i == "{//}":
                substituted = True
                yield str(experiment)
            else:
                yield i
        if not substituted:
            yield str(experiment / "hparams.yaml")

    with ThreadPoolExecutor(max_workers=j or None) as p:
        results = list(p.map(subprocess.run, (
            list(populate_cmd(experiment, cmd))
            for experiment in get_experiment_paths(filter)
            if not dumped or (experiment / "scalars/epoch.json").is_file()
            if not undumped or not (experiment / "scalars/epoch.json").is_file()
        )))

    if any(i.returncode for i in results):
        # `typer.Exit` is an exception: it has to be raised. Returning it (as was
        # done before) is ignored and leaves the exit status at 0.
        raise typer.Exit(1)
@app.command(help="Show stackplot of experiment loss")
def stackplot(filter: str, write: bool = False, dump: bool = False, training: bool = False, unscaled: bool = False, force: bool = False):
    """Stackplot the loss terms of every experiment matching `filter`.

    With `dump`, the tensorboard events are dumped first. All remaining options
    are passed on to `plot_losses`.
    """
    matched = list(get_experiment_paths(filter, assert_dumped=not dump))
    if not matched:
        raise typer.BadParameter("No match")
    if dump:
        try:
            tf_dump(matched)
        except typer.BadParameter:
            # raised by tf_dump when there is nothing (left) to dump
            pass
    options = dict(
        write = write,
        dump = dump,
        training = training,
        unscaled = unscaled,
        force = force,
    )
    plot_losses(matched, mode="stackplot", **options)
# The help text used to be a copy of the stackplot command's ("Show stackplot ...").
@app.command(help="Show lineplot of experiment loss")
def lineplot(filter: str, write: bool = False, dump: bool = False, training: bool = False, unscaled: bool = False, force: bool = False):
    """Lineplot the loss terms of every experiment matching `filter`.

    With `dump`, the tensorboard events are dumped first. All remaining options
    are passed on to `plot_losses`.
    """
    experiments = list(get_experiment_paths(filter, assert_dumped=not dump))
    if not experiments:
        raise typer.BadParameter("No match")
    if dump:
        try:
            tf_dump(experiments)
        except typer.BadParameter:
            # raised by tf_dump when there is nothing (left) to dump
            pass
    plot_losses(experiments,
        mode = "lineplot",
        write = write,
        dump = dump,
        training = training,
        unscaled = unscaled,
        force = force,
    )
@app.command(help="Open tensorboard for the experiments matching the regex")
def tensorboard(filter: Optional[str] = typer.Argument(None), watch: bool = False):
    """Run ``tensorboard`` on a temporary directory of symlinks to the matches.

    Without `filter` the cached summaries are listed instead. With `watch`,
    the experiment directories are re-scanned every 10 seconds while
    tensorboard is running and experiments that appeared in the meantime are
    linked in as well (this re-scan does not apply `filter`).
    """
    if filter is None:
        return list_cached_summaries("tensorboard")
    experiments = list(get_experiment_paths(filter, assert_dumped=False))
    if not experiments:
        raise typer.BadParameter("No match")

    with tempfile.TemporaryDirectory(suffix=f"ifield-{filter}") as d, ThreadPoolExecutor(max_workers=2) as p:
        treefarm = Path(d)

        def link(experiment: Path):
            (treefarm / experiment.name).symlink_to(experiment)

        for experiment in experiments:
            link(experiment)

        cmd = ["tensorboard", "--logdir", d]
        print("+", *map(shlex.quote, cmd), file=sys.stderr)
        server = p.submit(subprocess.run, cmd, check=True)

        if not watch:
            server.result()
            return

        known = set(get_experiment_paths(None, assert_dumped=False))
        while not server.done():
            time.sleep(10)
            fresh = set(get_experiment_paths(None, assert_dumped=False)) - known
            for experiment in fresh:
                print(f"Adding {experiment.name!r}...", file=sys.stderr)
                link(experiment)
            known.update(fresh)
@app.command(help="Compute evaluation metrics")
def metrics(filter: Optional[str] = typer.Argument(None), dump: bool = False, dry: bool = False, prefix: Optional[str] = typer.Option(None), derive: bool = False, each: bool = False, no_total: bool = False):
    """Compute (where missing) and print the score files of the matching experiments.

    Missing ``compute-scores/*.json`` files are produced by running
    ``./marf.py``; `prefix` is a shell-style command line put in front of each
    of these commands, and `dry` only prints them. `no_total` skips the
    computation and uses whichever score files already exist.

    With `derive`, the raw counts are turned into iou / precision / recall /
    f-score etc., summed over all objects, or per object as well with `each`.
    Without `filter` the cached summaries are listed instead.
    """
    if filter is None:
        return list_cached_summaries("metrics --derive")
    experiments = list(get_experiment_paths(filter, assert_dumped=False))
    if not experiments:
        raise typer.BadParameter("No match")
    if dump:
        try:
            tf_dump(experiments)
        except typer.BadParameter:
            # raised by tf_dump when there is nothing (left) to dump
            pass

    def run(*cmd):
        if prefix is not None:
            cmd = [*shlex.split(prefix), *cmd]
        quoted = [shlex.quote(str(part)) for part in cmd]
        if dry:
            print(*quoted)
        else:
            print("+", *quoted)
            subprocess.run(cmd)

    def needs_outlier_scores(experiment: Path) -> bool:
        return "2prif-" in experiment.name

    # (score file, checkpoint flag, computed with --filter-outliers)
    score_jobs = (
        ("metrics.json", "--best", False),
        ("metrics-last.json", "--last", False),
        ("metrics-sans_outliers.json", "--best", True),
        ("metrics-last-sans_outliers.json", "--last", True),
    )

    if not no_total:
        for experiment in experiments:
            for fname, checkpoint, sans_outliers in score_jobs:
                if sans_outliers and not needs_outlier_scores(experiment):
                    continue
                target = experiment / f"compute-scores/{fname}"
                if target.is_file():
                    continue
                extra = ("--filter-outliers",) if sans_outliers else ()
                run(
                    "python", "./marf.py", "module", checkpoint, experiment / "hparams.yaml",
                    "compute-scores", target,
                    "--transpose", *extra,
                )

    if dry: return
    if prefix is not None:
        print("prefix was used, assuming a job scheduler was used, will not print scores.", file=sys.stderr)
        return

    # Grouped by score file (not by experiment), in the order of `score_jobs`.
    metric_files = [
        experiment / f"compute-scores/{fname}"
        for fname, checkpoint, sans_outliers in score_jobs
        for experiment in experiments
        if not sans_outliers or needs_outlier_scores(experiment)
    ]
    if not no_total:
        assert all(metric.exists() for metric in metric_files)
    else:
        metric_files = (metric for metric in metric_files if metric.exists())

    out = []
    for metric in metric_files:
        experiment = metric.parent.parent.name
        is_last = metric.name in ("metrics-last.json", "metrics-last-sans_outliers.json")
        with metric.open() as f:
            data = json.load(f)

        if derive:
            derived = {}
            objs = [i for i in data.keys() if i != "_hparams"]
            # every object first (only with `each`), then `None` for the total
            for target in [*(objs if each else []), None]:
                if target is None:
                    # sum the raw counts of all objects
                    d = DefaultMunch(0)
                    for member in objs:
                        for k, v in data[member].items():
                            d[k] += v
                    target = "_all_"
                    n_cd = data["_hparams"]["n_cd"] * len(objs)
                    n_emd = data["_hparams"]["n_emd"] * len(objs)
                else:
                    d = munchify(data[target])
                    n_cd = data["_hparams"]["n_cd"]
                    n_emd = data["_hparams"]["n_emd"]

                precision = d.TP / (d.TP + d.FP)
                recall = d.TP / (d.TP + d.FN)
                derived[target] = dict(
                    filtered = d.n_outliers / d.n if "n_outliers" in d else None,
                    iou = d.TP / (d.TP + d.FN + d.FP),
                    precision = precision,
                    recall = recall,
                    f_score = 2 * (precision * recall) / (precision + recall),
                    cd = d.cd_dist / n_cd,
                    emd = d.emd / n_emd,
                    cos_med = 1 - (d.cd_cos_med / n_cd) if "cd_cos_med" in d else None,
                    cos_jac = 1 - (d.cd_cos_jac / n_cd),
                )
            data = derived if each else derived["_all_"]

        data["uid"] = experiment.rsplit("-", 1)[-1]
        data["experiment_name"] = experiment
        data["is_last"] = is_last
        out.append(json.dumps(data))

    report = "\n".join(out)
    if derive and not each and os.isatty(0) and os.isatty(1) and shutil.which("vd"):
        subprocess.run(["vd", "-f", "jsonl"], input=report, text=True, check=True)
    else:
        print(report)
# Script entry point: dispatch to the typer application when executed directly.
if __name__ == "__main__":
    app()