Files
marf/ifield/cli.py
2025-01-09 15:43:11 +01:00

1007 lines
48 KiB
Python

from . import logging, param
from .utils import helpers
from .utils.helpers import camel_to_snake_case
from argparse import ArgumentParser, _SubParsersAction, Namespace
from contextlib import contextmanager
from datetime import datetime
from functools import partial
from munch import Munch, munchify
from pathlib import Path
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from serve_me_once import serve_once_in_background, gen_random_port
from torch import nn
from tqdm import tqdm
from typing import Optional, Callable, TypeVar, Union, Any
import argparse, collections, copy
import inspect, io, os, platform, psutil, pygments, pygments.lexers, pygments.formatters
import pytorch_lightning as pl, re, rich, rich.pretty, shlex, shutil, string, subprocess, sys, textwrap
import traceback, time, torch, torchviz, urllib.parse, warnings, webbrowser, yaml
# Shared rich console; cap the width at 140 columns when stdout is not a TTY
# (e.g. redirected to a log file), otherwise let rich auto-detect it.
CONSOLE = rich.console.Console(width=None if os.isatty(1) else 140)
# keep printed tensors short instead of dumping thousands of elements
torch.set_printoptions(threshold=200)
# https://gist.github.com/pypt/94d747fe5180851196eb#gistcomment-3595282
#class UniqueKeyYAMLLoader(yaml.SafeLoader):
class UniqueKeyYAMLLoader(yaml.Loader):
    """A yaml.Loader which raises KeyError when a mapping contains the same key twice."""
    def construct_mapping(self, node, deep=False):
        seen = set()
        for key_node, _value_node in node.value:
            key = self.construct_object(key_node, deep=deep)
            if key in seen:
                raise KeyError(f"Duplicate {key!r} key found in YAML.")
            seen.add(key)
        return super().construct_mapping(node, deep)
# load scientific notation correctly as floats and not as strings
# basically, support for the to_json filter in jinja
# https://stackoverflow.com/a/30462009
# https://github.com/yaml/pyyaml/issues/173
UniqueKeyYAMLLoader.add_implicit_resolver(
u'tag:yaml.org,2002:float',
re.compile(u'''^(?:
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|[-+]?\\.(?:inf|Inf|INF)
|\\.(?:nan|NaN|NAN))$''', re.X),
list(u'-+0123456789.'))
class IgnorantActionsContainer(argparse._ActionsContainer):
    """
    Ignores conflicts with already-registered arguments, keeping the first definition.
    Must be enabled with ArgumentParser(conflict_handler="ignore")
    """
    # https://stackoverflow.com/a/71782808
    # argparse resolves the handler by name: conflict_handler="ignore" makes it
    # call _handle_conflict_ignore(); a no-op body silently drops the conflict.
    def _handle_conflict_ignore(self, action, conflicting_actions):
        pass
# Monkeypatch: splice IgnorantActionsContainer into the base classes of
# ArgumentParser and _ArgumentGroup so that every parser and argument group
# understands conflict_handler="ignore" (used via ignore_action_container_conflicts).
argparse.ArgumentParser.__bases__ = (argparse._AttributeHolder, IgnorantActionsContainer)
argparse._ArgumentGroup.__bases__ = (IgnorantActionsContainer,)
@contextmanager
def ignore_action_container_conflicts(parser: Union[argparse.ArgumentParser, argparse._ArgumentGroup]):
    """
    Temporarily set the parser's conflict_handler to "ignore" (see
    IgnorantActionsContainer), restoring the previous handler on exit.

    BUGFIX: the restore is now in a ``finally`` block — previously an exception
    raised inside the ``with`` body left the parser stuck on "ignore".
    """
    old = parser.conflict_handler
    parser.conflict_handler = "ignore"
    try:
        yield
    finally:
        parser.conflict_handler = old
def _print_with_syntax_highlighting(language, string, indent=""):
    """Print *string* to stdout, pygments-colorized when fd 1 is a terminal, optionally indented."""
    if os.isatty(1):
        lexer = pygments.lexers.get_lexer_by_name(language)
        formatter = pygments.formatters.Terminal256Formatter(style="monokai")
        string = pygments.highlight(string, lexer=lexer, formatter=formatter)
    if indent:
        string = textwrap.indent(string, indent)
    print(string)
def print_column_dict(data: dict, n_columns: int = 2, prefix: str=" "):
    """
    Pretty-print a dict: short values are columnized, long values get one line
    each (keys right-padded), and nested dicts are printed recursively.
    """
    small, wide, nested = {}, {}, {}
    for key, val in data.items():
        if isinstance(val, dict):
            nested[key] = val
        elif len(repr(val)) <= 40:
            small[key] = val
        else:
            wide[key] = val
    print_kw = dict(
        crop = False,
        overflow = "ignore",
    )
    if small:
        CONSOLE.print(helpers.columnize_dict(small, prefix=prefix, n_columns=n_columns, sep=" "), **print_kw)
    key_width = max((len(repr(k)) for k in wide.keys()), default=0)
    for key, val in wide.items():
        CONSOLE.print(f"{prefix}{repr(key).ljust(key_width)} : {val!r},", **print_kw)
    for key, val in nested.items():
        CONSOLE.print(f"{prefix}{key!r}: {{", **print_kw)
        print_column_dict(val, n_columns=n_columns, prefix=prefix+" ")
        CONSOLE.print(f"{prefix}}},", **print_kw)
# Type variables for the pluggable model and datamodule classes
M = TypeVar("M", bound=nn.Module)
DM = TypeVar("DM", bound=pl.LightningDataModule)
# Signature of pre/post-fit hooks: (cli args, config, module, trainer, datamodule, logger)
FitHook = Callable[[Namespace, Munch, M, pl.Trainer, DM, logging.Logger], None]
class CliInterface:
    """Command-line interface around a LightningModule class: config templating, model inspection actions, and training."""
    # keyword defaults applied to pl.Trainer, populated in __init__
    trainer_defaults: dict
def __init__(self, *, module_cls: type[M], workdir: Path, datamodule_cls: Union[list[type[DM]], type[DM], None] = None, experiment_name_prefix = "experiment"):
    """
    Set up the CLI around *module_cls*.

    module_cls: the LightningModule subclass this CLI drives
    workdir: default root directory for logs/outputs
    datamodule_cls: zero, one, or several LightningDataModule classes
    experiment_name_prefix: prepended to generated experiment names
    """
    self.module_cls = module_cls
    # normalize to a list (or None) so the rest of the class can iterate it
    self.datamodule_cls = [datamodule_cls] if not isinstance(datamodule_cls, list) and datamodule_cls is not None else datamodule_cls
    self.workdir = workdir
    self.experiment_name_prefix = experiment_name_prefix
    self.trainer_defaults = dict(
        enable_model_summary = False,
    )
    self.pre_fit_handlers: list[FitHook] = []
    self.post_fit_handlers: list[FitHook] = []
    # action cli-name -> (callable, extra add_argument tuples, add_parser kwargs, add_argparse_args hook(s))
    self._registered_actions : dict[str, tuple[Callable[[Namespace, Munch, M], None], list, dict, Optional[callable]]] = {}
    # template section name -> (callable/class, make_jinja_template kwargs)
    self._included_in_config_template : dict[str, tuple[callable, dict]] = {}
    # built-in inspection actions, exposed as subcommands of "module"
    self.register_action(_func=self.repr, help="Print str(module).", args=[])
    self.register_action(_func=self.yaml, help="Print evaluated config.", args=[])
    self.register_action(_func=self.hparams, help="Print hparams, like during training.", args=[])
    self.register_action(_func=self.dot, help="Print graphviz graph of computation graph.", args=[
        ("-e", "--eval", dict(action="store_true")),
        ("-f", "--filter", dict(action="store_true")),
    ])
    self.register_action(_func=self.jit, help="Print a TorchScript graph of the model", args=[])
    self.register_action(_func=self.trace, help="Dump a TorchScript trace of the model.", args=[
        ("output_file", dict(type=Path,
            help="Path to write the .pt file. Use \"-\" to instead open the trace in Netron.app")),
    ])
    self.register_action(_func=self.onnx, help="Dump a ONNX trace of the model.", args=[
        ("output_file", dict(type=Path,
            help="Path to write the .onnx file. Use \"-\" to instead open the onnx in Netron.app")),
    ])
    if self.datamodule_cls:
        # datamodule names must be unambiguous, both as-is and in snake-case CLI form
        names = [i.__name__ for i in self.datamodule_cls]
        names_snake = [datamodule_name_to_snake_case(i) for i in names]
        assert len(names) == len(set(names)),\
            f"Datamodule names are not unique: {names!r}"
        assert len(names) == len(set(names_snake)),\
            f"Datamodule snake-names are not unique: {names_snake!r}"
        self.register_action(_func=self.test_dataloader,
            help="Benchmark the speed of the dataloader",
            args=[
                ("datamodule", dict(type=str, default=None, nargs='?', choices=names_snake,
                    help="Which dataloader to test. Defaults to the first one found in config.")),
                ("--limit-cores", dict(type=int, default=None,
                    help="Limits the cpu affinity to N cores. Perfect to simulate a SLURM environ.")),
                ("--profile", dict(type=Path, default=None,
                    help="Profile using cProfile, marshaling the result to a .prof or .log file.")),
                ("-n", "--n-rounds", dict(type=int, default=3,
                    help="Number of times to read the dataloader.")),
            ],
            # multiple datamodules may define overlapping options; merge them silently
            conflict_handler = "ignore" if len(self.datamodule_cls) > 1 else "error",
            add_argparse_args=[i.add_argparse_args for i in self.datamodule_cls],
        )
# decorator
def register_pre_training_callback(self, func: FitHook):
    """Register *func* to run just before training starts; usable as a decorator."""
    self.pre_fit_handlers += [func]
    return func
# decorator
def register_post_training_callback(self, func: FitHook):
    """Register *func* to run right after training finishes; usable as a decorator."""
    self.post_fit_handlers += [func]
    return func
# decorator
def register_action(self, *,
        help : str,
        args : Optional[list] = None,
        _func : Optional[Callable[[Namespace, Munch, M], None]] = None,
        add_argparse_args : Union[list[Callable[[ArgumentParser], ArgumentParser]], Callable[[ArgumentParser], ArgumentParser], None] = None,
        **kw,
        ):
    """
    Register a module action as a CLI subcommand; usable as a decorator.

    help: help text for the subcommand
    args: list of add_argument tuples (*flags, kwargs_dict)
    _func: shortcut — register this callable directly instead of decorating
    add_argparse_args: extra hook(s) adding arguments to the action's subparser
    kw: forwarded to add_subparsers().add_parser()
    """
    # BUGFIX: was `args=[]`, a shared mutable default — every registration that
    # omitted `args` stored a reference to the *same* list object.
    if args is None:
        args = []
    def wrapper(action: Callable[[Namespace, Munch, M], None]):
        # "my_action" -> cli name "my-action"
        cli_name = action.__name__.lower().replace("_", "-")
        self._registered_actions[cli_name] = (
            action,
            args,
            kw | {"help": help},
            add_argparse_args,
        )
        return action
    if _func is not None: # shortcut
        return wrapper(_func)
    else:
        return wrapper
def make_parser(self,
        parser : ArgumentParser = None,
        subparsers : _SubParsersAction = None,
        add_trainer : bool = False,  # NOTE(review): unused in this method's visible body — confirm before removing
        ) -> tuple[ArgumentParser, _SubParsersAction, _SubParsersAction]:
    """
    Build the full argparse CLI: the "template" mode, the "module"/"model" mode
    with all registered actions, and (when datamodules exist) the "fit"/"test"
    trainer subcommand.

    Returns (parser, mode subparsers, module-action subparsers).
    """
    if parser is None:
        parser = ArgumentParser()
    if subparsers is None:
        subparsers = parser.add_subparsers(dest="mode", required=True)
        parser.add_argument("-pm", "--post-mortem", action="store_true",
            help="Start a debugger if a uncaught exception is thrown.")
    # Template generation and exploration
    parser_template = subparsers.add_parser("template",
        help="Generate or evaluate a config template")
    if 1: # fold me
        parser_mode_mutex = parser_template.add_mutually_exclusive_group()#(required=True)
        parser_mode_mutex.add_argument("-e", "--evaluate", metavar="TEMPLATE", type=Path,
            help="Read jinja2 yaml config template file, then evaluate and print it.")
        parser_mode_mutex.add_argument("-p", "--parse", metavar="TEMPLATE", type=Path,
            help="Read jinja2 yaml config template file, then evaluate, parse and print it.")
        def pair(data: str) -> tuple[str, str]:
            # argparse `type=` for "-O key=value": fall back to the environment
            # when "=value" is omitted, and $-expand values containing "$".
            key, sep, value = data.partition("=")
            if not sep:
                if key in os.environ:
                    value = os.environ[key]
                else:
                    raise ValueError(f"the variable {key!r} was not given any value, and none was found in the environment.")
            elif "$" in value:
                value = string.Template(value).substitute(os.environ)
            return (key, value)
        parser_template.add_argument("-O", dest="jinja2_variables", action="append", type=pair,
            help="Variable available as string in the jinja2. (a=b). b will be expanded as an"
            " env var if prefixed with $, or set equal to the env var a if =b is omitted.")
        parser_template.add_argument("-s", "--strict", action="store_true",
            help="Enable {% do require_defined(\"var\",var) %}".replace("%", "%%"))
        parser_template.add_argument("-d", "--defined-only", action="store_true",
            help="Disallow any use of undefined variables")
    # Load a module
    parser_module = subparsers.add_parser("module", aliases=["model"],
        help="Load a config template, evaluate it and use the resulting module")
    if 1: # fold me
        parser_module.add_argument("module_file", type=Path,
            help="Jinja2 yaml config template or pytorch-lightning .ckpt file.")
        parser_module.add_argument("-O", dest="jinja2_variables", action="append", type=pair,
            help="Variable available as string in the jinja2. (a=b). b will be expanded as an"
            " env var if prefixed with $, or set equal to the env var a if =b is omitted.")
        parser_module.add_argument("--last", action="store_true",
            help="if multiple ckpt match, prefer the last one")
        parser_module.add_argument("--best", action="store_true",
            help="if multiple ckpt match, prefer the best one")
        parser_module.add_argument("--add-shape-prehook", action="store_true",
            help="Add a forward hook which prints the tensor shapes of all inputs, but not the outputs.")
        parser_module.add_argument("--add-shape-hook", action="store_true",
            help="Add a forward hook which prints the tensor shapes of all inputs AND outputs.")
        parser_module.add_argument("--add-oob-hook", action="store_true",
            help="Add a forward hook checking for INF and NaN values in inputs or outputs.")
        parser_module.add_argument("--add-oob-hook-input", action="store_true",
            help="Add a forward hook checking for INF and NaN values in inputs.")
        parser_module.add_argument("--add-oob-hook-output", action="store_true",
            help="Add a forward hook checking for INF and NaN values in outputs.")
    module_actions_subparser = parser_module.add_subparsers(dest="action", required=True)
    # add pluggables
    for name, (action, args, kw, add_argparse_args) in self._registered_actions.items():
        action_parser = module_actions_subparser.add_parser(name, **kw)
        if add_argparse_args is not None and add_argparse_args:
            for func in add_argparse_args if isinstance(add_argparse_args, list) else [add_argparse_args]:
                action_parser = func(action_parser)
        # NOTE: `kw` is rebound by this unpacking; harmless, since the outer
        # loop re-unpacks it on the next iteration before any further use.
        for *a, kw in args:
            action_parser.add_argument(*a, **kw)
    # Module: train or test
    if self.datamodule_cls:
        parser_trainer = module_actions_subparser.add_parser("fit", aliases=["test"],
            help="Train/fit or evaluate the module with train/val or test data.")
        # pl.Trainer
        parser_trainer = pl.Trainer.add_argparse_args(parser_trainer)
        # datamodule
        parser_trainer.add_argument("datamodule", type=str, default=None, nargs='?',
            choices=[datamodule_name_to_snake_case(i) for i in self.datamodule_cls],
            help="Which dataloader to test. Defaults to the first one found in config.")
        if len(self.datamodule_cls) > 1:
            # check that none of the datamodules conflict with trainer or module
            for datamodule_cls in self.datamodule_cls:
                datamodule_cls.add_argparse_args(copy.deepcopy(parser_trainer)) # will raise on conflict
        # Merge the datamodule options, the above sanity check makes it "okay"
        with ignore_action_container_conflicts(parser_trainer):
            for datamodule_cls in self.datamodule_cls:
                parser_trainer = datamodule_cls.add_argparse_args(parser_trainer)
        # defaults and jinja template
        self._included_in_config_template.clear()
        remove_options_from_parser(parser_trainer, "--logger")
        parser_trainer.set_defaults(**self.trainer_defaults)
        self.add_to_jinja_template("trainer", pl.Trainer, defaults=self.trainer_defaults, exclude_list={
            # not yaml friendly, already covered anyway:
            "logger",
            "plugins",
            "callbacks",
            # deprecated or covered by callbacks:
            "stochastic_weight_avg",
            "enable_model_summary",
            "track_grad_norm",
            "log_gpu_memory",
        })
        for datamodule_cls in self.datamodule_cls:
            self.add_to_jinja_template(datamodule_cls.__name__, datamodule_cls,
                comment=f"select with {datamodule_name_to_snake_case(datamodule_cls)!r}")#, commented=False)
        self.add_to_jinja_template("logging", logging, save_dir = "logdir", commented=False)
    return parser, subparsers, module_actions_subparser
def add_to_jinja_template(self, name: str, func: callable, **kwargs):
    """
    Queue *func* as a section of the generated config template (see
    `make_jinja_template`); the section is commented-out unless overridden.
    Will ensure the keys are present in the output from `from_argparse_args`.
    """
    options = {"commented": True, **kwargs}
    self._included_in_config_template[name] = (func, options)
def make_jinja_template(self) -> str:
    """
    Assemble the full jinja2+yaml config template: a hyper-parameter matrix
    loop, every section queued via `add_to_jinja_template`, the module's own
    template, and trailing usage hints for local/SLURM runs.
    """
    return "\n".join([
        f'#!/usr/bin/env -S python {sys.argv[0]} module',
        r'{% do require_defined("select", select, 0, "$SLURM_ARRAY_TASK_ID") %}{# requires jinja2.ext.do #}',
        r"{% set counter = itertools.count(start=0, step=1) %}",
        r"",
        r"{% set hp_matrix = namespace() %}{# hyper parameter matrix #}",
        r"{% set hp_matrix.my_hparam = [0] %}{##}",
        r"",
        r"{% for hp in cartesian_hparams(hp_matrix) %}{##}",
        r"{#% for hp in ablation_hparams(hp_matrix, caartesian_keys=[]) %}{##}",
        r"",
        r"{% set index = next(counter) %}",
        r"{% if select is not defined and index > 0 %}---{% endif %}",
        r"{% if select is not defined or int(select) == index %}",
        r"",
        # each queued section renders itself if it knows how, else via param
        *[
            func.make_jinja_template(name=name, **kwargs)
            if hasattr(func, "make_jinja_template") else
            param.make_jinja_template(func, name=name, **kwargs)
            for name, (func, kwargs) in self._included_in_config_template.items()
        ],
        r"{% autoescape false %}",
        r'{% do require_defined("experiment_name", experiment_name, "test", strict=true) %}',
        f"experiment_name: { self.experiment_name_prefix }-{{{{ experiment_name }}}}",
        r'{#--#}-{{ hp.my_hparam }}',
        r'{#--#}-{{ gen_run_uid(4) }} # select with -Oselect={{ index }}',
        r"{% endautoescape %}",
        self.module_cls.make_jinja_template(),
        r"{% endif %}{# -Oselect #}",
        r"",
        r"{% endfor %}",
        r"",
        r"{% set index = next(counter) %}",
        r"# number of possible 'select': {{ index }}, from 0 to {{ index-1 }}",
        r"# local: for select in {0..{{ index-1 }}}; do python ... -Oselect=$select ... ; done",
        r"# local: for select in {0..{{ index-1 }}}; do python -O {{ argv[0] }} model marf.yaml.j2 -Oselect=$select -Oexperiment_name='{{ experiment_name }}' fit --accelerator gpu ; done",
        r"# slurm: sbatch --array=0-{{ index-1 }} runcommand.slurm python ... -Oselect=\$SLURM_ARRAY_TASK_ID ...",
        r"# slurm: sbatch --array=0-{{ index-1 }} runcommand.slurm python -O {{ argv[0] }} model this-file.yaml.j2 -Oselect=\$SLURM_ARRAY_TASK_ID -Oexperiment_name='{{ experiment_name }}' fit --accelerator gpu --devices -1 --strategy ddp"
    ])
def run(self, args=None, args_hook: Optional[Callable[[ArgumentParser, _SubParsersAction, _SubParsersAction], None]] = None):
    """
    Parse argv (or *args*) and dispatch to `handle_args`, optionally entering
    a post-mortem debugger (-pm) on uncaught exceptions when a TTY is attached.

    args: argument list for parse_args (None = sys.argv)
    args_hook: callback to extend the parser before parsing
    """
    parser, mode_subparser, action_subparser = self.make_parser()
    if args_hook is not None:
        args_hook(parser, mode_subparser, action_subparser)
    args = parser.parse_args(args) # may exit
    # BUGFIX: condition was inverted (`os.isatty(0) and args.post_mortem`),
    # firing the "no TTY attached" warning exactly when a TTY *was* attached.
    if args.post_mortem and not os.isatty(0):
        warnings.warn("post-mortem debugging is enabled without any TTY attached. Will be ignored.")
    if args.post_mortem and os.isatty(0):
        try:
            self.handle_args(args)
        except Exception:
            # print exception
            sys.excepthook(*sys.exc_info())
            # enter the debugger named by PYTHONBREAKPOINT (module part only)
            *debug_module, debug_func = os.environ.get("PYTHONBREAKPOINT", "pdb.set_trace").split(".")
            __import__(".".join(debug_module)).post_mortem()
            exit(1)
    else:
        self.handle_args(args)
def handle_args(self, args: Namespace):
    """
    Dispatch a parsed Namespace: "template" mode renders/evaluates a config
    template; "module"/"model" mode loads a module (from .ckpt, .j2, yaml or
    an experiment uid), optionally attaches debug forward hooks, and runs the
    requested action.
    May call exit()
    """
    if args.mode == "template":
        if args.evaluate or args.parse:
            template_file = args.evaluate or args.parse
            env = param.make_jinja_env(globals=param.make_jinja_globals(enable_require_defined=args.strict), allow_undef=not args.defined_only)
            if str(template_file) == "-":
                template = env.from_string(sys.stdin.read(), globals=dict(args.jinja2_variables or []))
            else:
                template = env.get_template(str(template_file.absolute()), globals=dict(args.jinja2_variables or []))
            config_yaml = param.squash_newlines(template.render())#.lstrip("\n").rstrip()
            if args.evaluate:
                _print_with_syntax_highlighting("yaml+jinja", config_yaml)
            else:
                config = yaml.load(config_yaml, UniqueKeyYAMLLoader)
                CONSOLE.print(config)
        else:
            # no -e/-p given: emit a fresh template instead
            _print_with_syntax_highlighting("yaml+jinja", self.make_jinja_template())
    elif args.mode in ("module", "model"):
        module: nn.Module
        if not args.module_file.is_file():
            # not a path: treat the argument as an experiment uid in the logdir
            matches = [*Path("logdir/tensorboard").rglob(f"*-{args.module_file}/checkpoints/*.ckpt")]
            if len(matches) == 1:
                args.module_file, = matches
            elif len(matches) > 1:
                if (args.last or args.best) and len(set(match.parent.parent.name for match in matches)) == 1:
                    if args.last:
                        args.module_file, = (match for match in matches if match.name == "last.ckpt")
                    elif args.best:
                        args.module_file, = (match for match in matches if match.name.startswith("epoch="))
                    else:
                        assert False
                else:
                    raise ValueError("uid matches multiple paths:\n"+"\n".join(map(str, matches)))
            else:
                raise ValueError("path does not exist, and is not a uid")
        # load module from cli args
        if args.module_file.suffix == ".ckpt": # from checkpoint
            # load from checkpoint
            rich.print(f"Loading module from {str(args.module_file)!r}...", file=sys.stderr)
            module = self.module_cls.load_from_checkpoint(args.module_file)
            if (args.module_file.parent.parent / "hparams.yaml").is_file():
                with (args.module_file.parent.parent / "hparams.yaml").open() as f:
                    config_yaml = yaml.load(f.read(), UniqueKeyYAMLLoader)["_pickled_cli_args"]["_raw_yaml"]
            else:
                with (args.module_file.parent.parent / "config.yaml").open() as f:
                    config_yaml = f.read()
            config = munchify(yaml.load(config_yaml, UniqueKeyYAMLLoader) | {"_raw_yaml": config_yaml})
        else: # from yaml
            # read, evaluate and parse config
            if args.module_file.suffix == ".j2" or str(args.module_file) == "-":
                env = param.make_jinja_env()
                if str(args.module_file) == "-":
                    template = env.from_string(sys.stdin.read(), globals=dict(args.jinja2_variables or []))
                else: # jinja+yaml file
                    template = env.get_template(str(args.module_file.absolute()), globals=dict(args.jinja2_variables or []))
                config_yaml = param.squash_newlines(template.render()).lstrip("\n").rstrip()
            else: # yaml file (the git diffs in _pickled_cli_args may trigger jinja's escape sequences)
                with args.module_file.open() as f:
                    config_yaml = f.read().lstrip("\n").rstrip()
            config = yaml.load(config_yaml, UniqueKeyYAMLLoader)
            if "_pickled_cli_args" in config: # hparams.yaml in tensorboard logdir
                config_yaml = config["_pickled_cli_args"]["_raw_yaml"]
                config = yaml.load(config_yaml, UniqueKeyYAMLLoader)
            from_checkpoint: Optional[Path] = None
            # BUGFIX: `.glob()` returns a generator, which is ALWAYS truthy, so
            # the old `if (...).glob("*.ckpt"):` guard was a no-op and --last/
            # --best crashed with an unpacking ValueError on an empty dir.
            # Materialize the listing first and branch on the actual files.
            checkpoints_fnames = list((args.module_file.parent / "checkpoints").glob("*.ckpt"))
            if checkpoints_fnames:
                if len(checkpoints_fnames) == 1:
                    from_checkpoint = checkpoints_fnames[0]
                elif args.last:
                    from_checkpoint, = (i for i in checkpoints_fnames if i.name == "last.ckpt")
                elif args.best:
                    from_checkpoint, = (i for i in checkpoints_fnames if i.name.startswith("epoch="))
                elif len(checkpoints_fnames) > 1:
                    rich.print(f"[yellow]WARNING:[/] {str(args.module_file.parent / 'checkpoints')!r} contains more than one checkpoint, unable to automatically load one.", file=sys.stderr)
            config = munchify(config | {"_raw_yaml": config_yaml})
            # Ensure date and uid to experiment name, allowing for reruns and organization
            assert config.experiment_name
            assert re.match(r'^.*-[0-9]{4}-[0-9]{2}-[0-9]{2}-[0-9]{4}-[a-z]{4}$', config.experiment_name),\
                config.experiment_name
            # init the module
            if from_checkpoint:
                rich.print(f"Loading module from {str(from_checkpoint)!r}...", file=sys.stderr)
                module = self.module_cls.load_from_checkpoint(from_checkpoint)
            else:
                module = self.module_cls(**{k:v for k, v in config[self.module_cls.__name__].items() if k != "_extra"})
        # optional debugging forward hooks
        if args.add_shape_hook or args.add_shape_prehook:
            def shape_forward_hook(is_prehook: bool, name: str, module: nn.Module, input, output=None):
                def tensor_to_shape(val):
                    if isinstance(val, torch.Tensor):
                        return tuple(val.shape)
                    elif isinstance(val, (str, float, int)) or val is None:
                        return 1
                    else:
                        assert 0, (val, name)
                with torch.no_grad():
                    rich.print(
                        f"{name}.forward({helpers.map_tree(tensor_to_shape, input)})"
                        if is_prehook else
                        f"{name}.forward({helpers.map_tree(tensor_to_shape, input)})"
                        f" -> {helpers.map_tree(tensor_to_shape, output)}"
                    , file=sys.stderr)
            for submodule_name, submodule in module.named_modules():
                if submodule_name:
                    submodule_name = f"{module.__class__.__qualname__}.{submodule_name}"
                else:
                    submodule_name = f"{module.__class__.__qualname__}"
                if args.add_shape_prehook:
                    submodule.register_forward_pre_hook(partial(shape_forward_hook, True, submodule_name))
                if args.add_shape_hook:
                    submodule.register_forward_hook(partial(shape_forward_hook, False, submodule_name))
        if args.add_oob_hook or args.add_oob_hook_input or args.add_oob_hook_output:
            def oob_forward_hook(name: str, module: nn.Module, input, output):
                def raise_if_oob(key, val):
                    if isinstance(val, collections.abc.Mapping):
                        for k, subval in val.items():
                            raise_if_oob(f"{key}[{k!r}]", subval)
                    elif isinstance(val, (tuple, list)):
                        for i, subval in enumerate(val):
                            raise_if_oob(f"{key}[{i}]", subval)
                    elif isinstance(val, torch.Tensor):
                        assert not torch.isinf(val).any(), \
                            f"INFs found in {key}"
                        assert not val.isnan().any(), \
                            f"NaNs found in {key}"
                    elif isinstance(val, (str, float, int)):
                        pass
                    elif val is None:
                        warnings.warn(f"None found in {key}")
                    else:
                        assert False, val
                with torch.no_grad():
                    if args.add_oob_hook or args.add_oob_hook_input:
                        raise_if_oob(f"{name}.forward input", input)
                    if args.add_oob_hook or args.add_oob_hook_output:
                        raise_if_oob(f"{name}.forward output", output)
            for submodule_name, submodule in module.named_modules():
                submodule.register_forward_hook(partial(oob_forward_hook,
                    f"{module.__class__.__qualname__}.{submodule_name}"
                    if submodule_name else
                    f"{module.__class__.__qualname__}"
                ))
        # Ensure all the top-level config keys are there
        for key in self._included_in_config_template.keys():
            # BUGFIX: guard against self.datamodule_cls being None, which made
            # the genexp below raise TypeError.
            if self.datamodule_cls and key in (i.__name__ for i in self.datamodule_cls):
                continue
            if key not in config or config[key] is None:
                config[key] = {}
        # Run registered action
        if args.action in self._registered_actions:
            action, *_ = self._registered_actions[args.action]
            action(args, config, module)
        elif args.action in ("fit", "test") and self.datamodule_cls is not None:
            self.fit(args, config, module)
        else:
            raise ValueError(f"{args.mode=}, {args.action=}")
    else:
        raise ValueError(f"{args.mode=}")
def get_datamodule_cls_from_config(self, args: Namespace, config: Munch) -> DM:
    """
    Select the datamodule class: the explicit CLI choice wins, otherwise the
    first registered class whose name appears in the config, otherwise the
    first registered class (with a warning).
    """
    assert self.datamodule_cls
    requested = getattr(args, "datamodule", None)
    datamodule_cls: pl.LightningDataModule
    if requested is not None:
        # exactly one registered class must match the snake-case CLI name
        datamodule_cls, = [cls for cls in self.datamodule_cls
                           if datamodule_name_to_snake_case(cls) == requested]
    else:
        datamodules = {cls.__name__: cls for cls in self.datamodule_cls}
        found = next((datamodules[key] for key in config.keys() if key in datamodules), None)
        if found is not None:
            datamodule_cls = found
        else:
            datamodule_cls = self.datamodule_cls[0]
            warnings.warn(f"None of the following datamodules were found in config: {set(datamodules.keys())!r}. {datamodule_cls.__name__!r} was chosen as the default.")
    return datamodule_cls
def init_datamodule_cls_from_config(self, args: Namespace, config: Munch) -> DM:
    """Instantiate the selected datamodule from argparse args plus its config section."""
    selected_cls = self.get_datamodule_cls_from_config(args, config)
    section = config.get(selected_cls.__name__) or {}
    return selected_cls.from_argparse_args(args, **section)
# Module actions
def repr(self, args: Namespace, config: Munch, module: M):
    """Pretty-print the module itself (its nn.Module repr) via rich."""
    rich.print(module)
def yaml(self, args: Namespace, config: Munch, module: M):
    """Print the raw evaluated YAML config with syntax highlighting."""
    raw = config["_raw_yaml"]
    _print_with_syntax_highlighting("yaml+jinja", raw)
def dot(self, args: Namespace, config: Munch, module: M):
    """Run a forward pass on example_input_array and print the torchviz graph of the autograd graph."""
    module.train(not args.eval)
    assert not args.filter, "not implemented! pipe it through examples/scripts/filter_dot.py in the meanwhile"
    example_inputs = module.example_input_array
    assert example_inputs is not None, f"{module.__class__.__qualname__}.example_input_array=None"
    assert isinstance(example_inputs, (tuple, dict, torch.Tensor)), type(example_inputs)
    def enable_grad_on_tensors(val):
        # grad is required for the autograd graph to be recorded
        if isinstance(val, torch.Tensor):
            val.requires_grad = True
        return val
    with torch.enable_grad():
        outputs = module(*helpers.map_tree(enable_grad_on_tensors, example_inputs))
    graph = torchviz.make_dot(outputs, params=dict(module.named_parameters()), show_attrs=False, show_saved=False)
    _print_with_syntax_highlighting("dot", str(graph))
def jit(self, args: Namespace, config: Munch, module: M):
    """Trace forward() with example_input_array and print the inlined TorchScript graph."""
    example_inputs = module.example_input_array
    assert example_inputs is not None, f"{module.__class__.__qualname__}.example_input_array=None"
    assert isinstance(example_inputs, (tuple, dict, torch.Tensor)), type(example_inputs)
    traced = torch.jit.trace_module(module, {"forward": example_inputs})
    _print_with_syntax_highlighting("python", str(traced.inlined_graph))
def trace(self, args: Namespace, config: Munch, module: M):
    """
    Dump a TorchScript trace of the module to args.output_file, or serve it
    to Netron when output_file is "-".
    """
    if isinstance(module, pl.LightningModule):
        trace = module.to_torchscript(method="trace")
    else:
        example_input_array = module.example_input_array
        assert example_input_array is not None, f"{module.__class__.__qualname__}.example_input_array is None"
        # BUGFIX: was `torch.Module`, which does not exist (AttributeError at
        # runtime); `nn.Module` was clearly intended.
        assert isinstance(module, nn.Module)
        trace = torch.jit.trace_module(module, {"forward": example_input_array})
    use_netron = str(args.output_file) == "-"
    trace_f = io.BytesIO() if use_netron else args.output_file
    torch.jit.save(trace, trace_f)
    if use_netron:
        open_in_netron(f"{self.module_cls.__name__}.pt", trace_f.getvalue())
def onnx(self, args: Namespace, config: Munch, module: M):
    """
    Export the module to ONNX (opset 17) using example_input_array, writing to
    args.output_file or serving to Netron when output_file is "-".
    """
    example_input_array = module.example_input_array
    assert example_input_array is not None, f"{module.__class__.__qualname__}.example_input_array=None"
    assert isinstance(example_input_array, (tuple, dict, torch.Tensor)), type(example_input_array)
    use_netron = str(args.output_file) == "-"
    onnx_f = io.BytesIO() if use_netron else args.output_file
    # BUGFIX: `tuple(example_input_array)` splits a Tensor along dim 0 and
    # reduces a dict to its keys — both garbage for torch.onnx.export. Wrap
    # Tensors/dicts in a 1-tuple (a dict as the last element is interpreted by
    # export() as keyword arguments); only tuples/lists are converted as-is.
    if isinstance(example_input_array, (torch.Tensor, dict)):
        onnx_args = (example_input_array,)
    else:
        onnx_args = tuple(example_input_array)
    torch.onnx.export(module,
        onnx_args,
        onnx_f,
        export_params = True,
        opset_version = 17,
        do_constant_folding = True,
        input_names = ["input"],
        output_names = ["output"],
        dynamic_axes = {
            "input" : {0 : "batch_size"},
            "output" : {0 : "batch_size"},
        },
    )
    if use_netron:
        open_in_netron(f"{self.module_cls.__name__}.onnx", onnx_f.getvalue())
def hparams(self, args: Namespace, config: Munch, module: M):
    """Print the module's hparams in columns, as done during training (nn.Module values shown by class name)."""
    assert isinstance(module, self.module_cls)
    print(f"{self.module_cls.__qualname__} hparams:")
    shown = map_type_to_repr(module.hparams, nn.Module, lambda t: f"{t.__class__.__qualname__}")
    print_column_dict(shown, 3)
def fit(self, args: Namespace, config: Munch, module: M):
is_rank_zero = pl.utilities.rank_zero_only.rank == 0
metric_prefix = f"{module.__class__.__name__}.validation_step/"
pl_callbacks = [
pl.callbacks.LearningRateMonitor(log_momentum=True),
pl.callbacks.EarlyStopping(monitor=metric_prefix+getattr(module, "metric_early_stop", "loss"), patience=200, check_on_train_epoch_end=False, verbose=True),
pl.callbacks.ModelCheckpoint(monitor=metric_prefix+getattr(module, "metric_best_model", "loss"), mode="min", save_top_k=1, save_last=True),
logging.ModelOutputMonitor(),
logging.EpochTimeMonitor(),
(pl.callbacks.RichModelSummary if os.isatty(1) else pl.callbacks.ModelSummary)(max_depth=30),
logging.PsutilMonitor(),
]
if os.isatty(1):
pl_callbacks.append( pl.callbacks.RichProgressBar() )
trainer: pl.Trainer
logger = logging.make_logger(config.experiment_name, config.trainer.get("default_root_dir", args.default_root_dir or self.workdir), **config.logging)
trainer = pl.Trainer.from_argparse_args(args, logger=logger, callbacks=pl_callbacks, **config.trainer)
datamodule = self.init_datamodule_cls_from_config(args, config)
for f in self.pre_fit_handlers:
print(f"pre-train hook {f.__name__!r}...")
f(args, config, module, trainer, datamodule, logger)
# print and log hparams/config
if 1: # fold me
if is_rank_zero:
CONSOLE.print(f"Experiment name: {config.experiment_name!r}", soft_wrap=False, crop=False, no_wrap=False, overflow="ignore")
# parser.args and sys.argv
pickled_cli_args = dict(
sys_argv = sys.argv,
parser_args = args.__dict__,
config = config.copy(),
_raw_yaml = config["_raw_yaml"],
)
del pickled_cli_args["config"]["_raw_yaml"]
for k,v in pickled_cli_args["parser_args"].items():
if isinstance(v, Path):
pickled_cli_args["parser_args"][k] = str(v)
# trainer
params_trainer = inspect.signature(pl.Trainer.__init__).parameters
trainer_hparams = vars(pl.Trainer.parse_argparser(args))
trainer_hparams = { name: trainer_hparams[name] for name in params_trainer if name in trainer_hparams }
if is_rank_zero:
print("pl.Trainer hparams:")
print_column_dict(trainer_hparams, 3)
pickled_cli_args.update(trainer_hparams=trainer_hparams)
# module
assert isinstance(module, self.module_cls)
if is_rank_zero:
print(f"{self.module_cls.__qualname__} hparams:")
print_column_dict(map_type_to_repr(module.hparams, nn.Module, lambda t: f"{t.__class__.__qualname__}"), 3)
pickled_cli_args.update(module_hparams={
k : v
for k, v in module.hparams.items()
if k != "_raw_yaml"
})
# module extra state, like autodecoder uids
for submodule_name, submodule in module.named_modules():
if not submodule_name:
submodule_name = module.__class__.__qualname__
else:
submodule_name = module.__class__.__qualname__ + "." + submodule_name
try:
state = submodule.get_extra_state()
except RuntimeError:
continue
if "extra_state" not in pickled_cli_args:
pickled_cli_args["extra_state"] = {}
pickled_cli_args["extra_state"][submodule_name] = state
# datamodule
if self.datamodule_cls:
assert datamodule is not None and any(isinstance(datamodule, i) for i in self.datamodule_cls), datamodule
for datamodule_cls in self.datamodule_cls:
params_d = inspect.signature(datamodule_cls.__init__).parameters
assert {"self"} == set(params_trainer).intersection(params_d), \
f"trainer and datamodule has overlapping params: {set(params_trainer).intersection(params_d) - {'self'}}"
if is_rank_zero:
print(f"{datamodule.__class__.__qualname__} hparams:")
print_column_dict(datamodule.hparams)
pickled_cli_args.update(datamodule_hparams=dict(datamodule.hparams))
# logger
if logger is not None:
print(f"{logger.__class__.__qualname__} hparams:")
print_column_dict(config.logging)
pickled_cli_args.update(logger_hparams = {"_class": logger.__class__.__name__} | config.logging)
# host info
def cmd(cmd: Union[str, list[str]]) -> str:
    """Run a command and return its stripped stdout.

    Failures never raise: a missing executable or a non-zero exit status is
    reported as a warning and returned as a descriptive placeholder string,
    so host-info collection stays best-effort.
    """
    argv = shlex.split(cmd) if isinstance(cmd, str) else cmd
    if not shutil.which(argv[0]):
        warnings.warn(f"command {argv[0]!r} not found")
        return f"*command {argv[0]!r} not found*"
    try:
        completed = subprocess.run(argv,
            capture_output=True,
            check=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        warnings.warn(f"{e.__class__.__name__}: {e}")
        return f"{e.__class__.__name__}: {e}\n{e.output = }\n{e.stderr = }"
    return completed.stdout.strip()
pickled_cli_args.update(host = dict(
platform = textwrap.dedent(f"""
{platform.architecture() = }
{platform.java_ver() = }
{platform.libc_ver() = }
{platform.mac_ver() = }
{platform.machine() = }
{platform.node() = }
{platform.platform() = }
{platform.processor() = }
{platform.python_branch() = }
{platform.python_build() = }
{platform.python_compiler() = }
{platform.python_implementation() = }
{platform.python_revision() = }
{platform.python_version() = }
{platform.release() = }
{platform.system() = }
{platform.uname() = }
{platform.version() = }
""".rstrip()).lstrip(),
cuda = dict(
gpus = [
torch.cuda.get_device_name(i)
for i in range(torch.cuda.device_count())
],
available = torch.cuda.is_available(),
version = torch.version.cuda,
),
hostname = cmd("hostname --fqdn"),
cwd = os.getcwd(),
date = datetime.now().astimezone().isoformat(),
date_utc = datetime.utcnow().isoformat(),
ifconfig = cmd("ifconfig"),
lspci = cmd("lspci"),
lsusb = cmd("lsusb"),
lsblk = cmd("lsblk"),
mount = cmd("mount"),
environ = os.environ.copy(),
vcs = f"commit {cmd('git rev-parse HEAD')}\n{cmd('git status')}\n{cmd('git diff --stat --patch HEAD')}",
venv_pip = cmd("pip list --format=freeze"),
venv_conda = cmd("conda list"),
venv_poetry = cmd("poetry show -t"),
gpus = [i.split(", ") for i in cmd("nvidia-smi --query-gpu=index,name,memory.total,driver_version,uuid --format=csv").splitlines()],
))
if logger is not None:
logging.log_config(logger, _pickled_cli_args=pickled_cli_args)
warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning)
if __debug__ and is_rank_zero:
warnings.warn("You're running python with assertions active. Enable optimizations with `python -O` for improved performance.")
# train
t_begin = datetime.now()
if args.action == "fit":
trainer.fit(module, datamodule)
elif args.action == "test":
trainer.test(module, datamodule)
else:
raise ValueError(f"{args.mode=}, {args.action=}")
if not is_rank_zero:
return
t_end = datetime.now()
print(f"Training time: {t_end - t_begin}")
for f in self.post_fit_handlers:
print(f"post-train hook {f.__name__!r}...")
try:
f(args, config, module, trainer, datamodule, logger)
except Exception:
traceback.print_exc()
rich.print(f"Experiment name: {config.experiment_name!r}")
rich.print(f"Best model path: {helpers.make_relative(trainer.checkpoint_callback.best_model_path).__str__()!r}")
rich.print(f"Last model path: {helpers.make_relative(trainer.checkpoint_callback.last_model_path).__str__()!r}")
def test_dataloader(self, args: Namespace, config: Munch, module: M):
    """Benchmark the datamodule's dataloaders.

    Prints a structural summary of the first training batch, then iterates the
    train/val/test loaders for `args.n_rounds` epochs each and reports per-epoch,
    per-batch and batches/s timings. Optionally restricts CPU affinity
    (`args.limit_cores`) and records a cProfile dump (`args.profile`; "-" means
    print stats without writing a file). `module` is unused here; the signature
    matches the other action handlers.
    """
    # limit CPU affinity
    if args.limit_cores is not None:
        # https://stackoverflow.com/a/40856471
        p = psutil.Process()
        assert len(p.cpu_affinity()) >= args.limit_cores
        cpus = list(range(args.limit_cores))
        p.cpu_affinity(cpus)
        print("Process limited to CPUs", cpus)
    datamodule = self.init_datamodule_cls_from_config(args, config)
    # setup
    rich.print(f"Setup {datamodule.__class__.__qualname__}...")
    datamodule.prepare_data()
    datamodule.setup("fit")
    def try_dataloader(getter):
        # a datamodule may legitimately not implement all three splits
        try:
            return getter()
        except (MisconfigurationException, NotImplementedError):
            return None
    train = try_dataloader(datamodule.train_dataloader)
    val   = try_dataloader(datamodule.val_dataloader)
    test  = try_dataloader(datamodule.test_dataloader)
    # inspect: show batch structure with tensors abbreviated to shape/dtype/device
    rich.print("batch[0] = ", end="")
    rich.pretty.pprint(
        map_type_to_repr(
            next(iter(train)),
            torch.Tensor,
            lambda x: f"Tensor(..., shape={x.shape}, dtype={x.dtype}, device={x.device})",
        ),
        indent_guides = False,
    )
    if args.profile is not None:
        import cProfile
        profile = cProfile.Profile()
        profile.enable()
    # measure: accumulate batch counts and elapsed nanoseconds per split
    n_train, td_train = 0, 0
    n_val,   td_val   = 0, 0
    n_test,  td_test  = 0, 0
    try:
        for i in range(args.n_rounds):
            print(f"Round {i+1} of {args.n_rounds}")
            if train is not None:
                epoch = time.perf_counter_ns()
                n_train += sum(1 for _ in tqdm(train, desc=f"train {i+1}/{args.n_rounds}"))
                td_train += time.perf_counter_ns() - epoch
            if val is not None:
                epoch = time.perf_counter_ns()
                n_val += sum(1 for _ in tqdm(val, desc=f"val {i+1}/{args.n_rounds}"))
                td_val += time.perf_counter_ns() - epoch
            if test is not None:
                epoch = time.perf_counter_ns()
                # BUGFIX: this progress bar was mislabeled "train"
                n_test += sum(1 for _ in tqdm(test, desc=f"test {i+1}/{args.n_rounds}"))
                td_test += time.perf_counter_ns() - epoch
    except KeyboardInterrupt:
        rich.print("Received a `KeyboardInterrupt`...")
    if args.profile is not None:
        profile.disable()
        if args.profile != "-":
            profile.dump_stats(args.profile)
        profile.print_stats("tottime")
    # summary (td is in nanoseconds, hence the 1e-9 factors)
    for label, data, n, td in [
        ("train", train, n_train, td_train),
        ("val",   val,   n_val,   td_val),
        ("test",  test,  n_test,  td_test),
    ]:
        if not n: continue
        if data is not None:
            print(f"{label}:",
                f" - per epoch: {td / args.n_rounds * 1e-9 :11.6f} s",
                f" - per batch: {td / n * 1e-9 :11.6f} s",
                f" - batches/s: {n / (td * 1e-9):11.6f}",
                sep="\n")
    datamodule.teardown("fit")
# helpers:
def open_in_netron(filename: str, data: bytes, *, timeout: float = 10):
    """Serve `data` via a one-shot local HTTP server and open it in netron.app.

    `filename` is only used to determine the filetype: Netron infers the model
    format from the extension of the URL it is given, so only the suffix of
    `filename` matters. Blocks for `timeout` seconds (if nonzero) to keep the
    background server alive while the page fetches the data.
    """
    url = serve_once_in_background(
        data,
        mime_type = "application/octet-stream",
        timeout = timeout,
        port = gen_random_port(),
    )
    # BUGFIX: a literal "(unknown)" was appended here and `filename` was unused;
    # append the (quoted) filename so Netron can detect the format from its extension
    url = f"https://netron.app/?url={urllib.parse.quote(url)}{urllib.parse.quote(filename)}"
    print("Open in Netron:", url)
    webbrowser.get("firefox").open_new_tab(url)
    if timeout:
        time.sleep(timeout)
def remove_options_from_parser(parser: ArgumentParser, *options: str):
    """Remove the given option strings (e.g. "--lr") from `parser` in place.

    Options not present in the parser are silently ignored.
    """
    wanted = set(options)
    # https://stackoverflow.com/questions/32807319/disable-remove-argument-in-argparse/36863647#36863647
    # BUGFIX: iterate over a snapshot — _handle_conflict_resolve removes the
    # action from parser._actions, and mutating the list while iterating it
    # would skip the action immediately following each removal.
    for action in list(parser._actions):
        if action.option_strings:
            option = action.option_strings[0]
            if option in wanted:
                parser._handle_conflict_resolve(None, [(option, action)])
def map_type_to_repr(batch, type_match: type, repr_func: callable):
    """Walk a (possibly nested) `batch` structure and replace every value that
    is an instance of `type_match` with a CustomRepr wrapping `repr_func(value)`.

    Used to abbreviate e.g. tensors when pretty-printing a batch.
    """
    return helpers.map_tree(
        lambda value: helpers.CustomRepr(repr_func(value)) if isinstance(value, type_match) else value,
        batch,
    )
def datamodule_name_to_snake_case(datamodule: Union[str, type[DM]]) -> str:
    """Derive a kebab-cased name from a datamodule class (or its class name).

    Normalizes "DataModule" spelling, strips the "Datamodule" suffix (unless
    the name is exactly "Datamodule"), then converts CamelCase to dash-separated
    lowercase with abbreviations joined.
    """
    name = datamodule if isinstance(datamodule, str) else datamodule.__name__
    name = name.replace("DataModule", "Datamodule")
    # keep the bare name "Datamodule" intact; otherwise drop the suffix
    if name != "Datamodule":
        name = name.removesuffix("Datamodule")
    return camel_to_snake_case(name, sep="-", join_abbreviations=True)