from . import logging, param
from .utils import helpers
from .utils.helpers import camel_to_snake_case
from argparse import ArgumentParser, _SubParsersAction, Namespace
from contextlib import contextmanager
from datetime import datetime
from functools import partial
from munch import Munch, munchify
from pathlib import Path
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from serve_me_once import serve_once_in_background, gen_random_port
from torch import nn
from tqdm import tqdm
from typing import Optional, Callable, TypeVar, Union, Any
import argparse, collections, copy, importlib
import inspect, io, os, platform, psutil, pygments, pygments.lexers, pygments.formatters
import pytorch_lightning as pl, re, rich, rich.console, rich.pretty, shlex, shutil, string, subprocess, sys, textwrap
import traceback, time, torch, torchviz, urllib.parse, warnings, webbrowser, yaml


CONSOLE = rich.console.Console(width=None if os.isatty(1) else 140)
torch.set_printoptions(threshold=200)


# https://gist.github.com/pypt/94d747fe5180851196eb#gistcomment-3595282
#class UniqueKeyYAMLLoader(yaml.SafeLoader):
class UniqueKeyYAMLLoader(yaml.Loader):
    def construct_mapping(self, node, deep=False):
        mapping = set()
        for key_node, value_node in node.value:
            key = self.construct_object(key_node, deep=deep)
            if key in mapping:
                raise KeyError(f"Duplicate {key!r} key found in YAML.")
            mapping.add(key)
        return super().construct_mapping(node, deep)

# load scientific notation correctly as floats and not as strings
# basically, support for the to_json filter in jinja
# https://stackoverflow.com/a/30462009
# https://github.com/yaml/pyyaml/issues/173
UniqueKeyYAMLLoader.add_implicit_resolver(
    u'tag:yaml.org,2002:float',
    re.compile(u'''^(?:
     [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
    |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
    |\\.[0-9_]+(?:[eE][-+][0-9]+)?
    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
    |[-+]?\\.(?:inf|Inf|INF)
    |\\.(?:nan|NaN|NAN))$''', re.X),
    list(u'-+0123456789.'))
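
# Illustrative effect of the resolver above (results assumed, not from a test run):
#   yaml.load("lr: 1e-5", yaml.Loader)          # -> {"lr": "1e-5"}, a string
#   yaml.load("lr: 1e-5", UniqueKeyYAMLLoader)  # -> {"lr": 1e-05}, a float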


class IgnorantActionsContainer(argparse._ActionsContainer):
    """
    Ignores conflicts with already-registered arguments instead of raising.
    Must be enabled with ArgumentParser(conflict_handler="ignore").
    """
    # https://stackoverflow.com/a/71782808
    def _handle_conflict_ignore(self, action, conflicting_actions):
        pass

argparse.ArgumentParser.__bases__ = (argparse._AttributeHolder, IgnorantActionsContainer)
argparse._ArgumentGroup.__bases__ = (IgnorantActionsContainer,)


@contextmanager
def ignore_action_container_conflicts(parser: Union[argparse.ArgumentParser, argparse._ArgumentGroup]):
    old = parser.conflict_handler
    parser.conflict_handler = "ignore"
    try:
        yield
    finally:
        # restore the original handler even if the body raised
        parser.conflict_handler = old
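
# Illustrative usage (the parser is hypothetical): temporarily allow re-registering
# options while merging argparse args from several sources:
#
#   with ignore_action_container_conflicts(parser):
#       parser.add_argument("--batch-size")  # no longer raises if already added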


def _print_with_syntax_highlighting(language, string, indent=""):
    if os.isatty(1):
        string = pygments.highlight(string,
            lexer     = pygments.lexers.get_lexer_by_name(language),
            formatter = pygments.formatters.Terminal256Formatter(style="monokai"),
        )
    if indent:
        string = textwrap.indent(string, indent)
    print(string)


def print_column_dict(data: dict, n_columns: int = 2, prefix: str = " "):
    small = {k: v for k, v in data.items() if not isinstance(v, dict) and len(repr(v)) <= 40}
    wide  = {k: v for k, v in data.items() if not isinstance(v, dict) and len(repr(v)) > 40}
    dicts = {k: v for k, v in data.items() if isinstance(v, dict)}
    kw = dict(
        crop     = False,
        overflow = "ignore",
    )
    if small:
        CONSOLE.print(helpers.columnize_dict(small, prefix=prefix, n_columns=n_columns, sep=" "), **kw)
    key_len = max(map(len, map(repr, wide.keys()))) if wide else 0
    for key, val in wide.items():
        CONSOLE.print(f"{prefix}{repr(key).ljust(key_len)} : {val!r},", **kw)
    for key, val in dicts.items():
        CONSOLE.print(f"{prefix}{key!r}: {{", **kw)
        print_column_dict(val, n_columns=n_columns, prefix=prefix + " ")
        CONSOLE.print(f"{prefix}}},", **kw)


M  = TypeVar("M",  bound=nn.Module)
DM = TypeVar("DM", bound=pl.LightningDataModule)
FitHook = Callable[[Namespace, Munch, M, pl.Trainer, DM, logging.Logger], None]


class CliInterface:
    trainer_defaults: dict

    def __init__(self, *, module_cls: type[M], workdir: Path, datamodule_cls: Union[list[type[DM]], type[DM], None] = None, experiment_name_prefix="experiment"):
        self.module_cls = module_cls
        self.datamodule_cls = (
            [datamodule_cls]
            if datamodule_cls is not None and not isinstance(datamodule_cls, list)
            else datamodule_cls
        )
        self.workdir = workdir
        self.experiment_name_prefix = experiment_name_prefix

        self.trainer_defaults = dict(
            enable_model_summary = False,
        )

        self.pre_fit_handlers:  list[FitHook] = []
        self.post_fit_handlers: list[FitHook] = []

        # cli_name -> (action, argparse args, add_parser kwargs, add_argparse_args)
        self._registered_actions: dict[str, tuple[Callable[[Namespace, Munch, M], None], list, dict, Optional[callable]]] = {}
        self._included_in_config_template: dict[str, tuple[callable, dict]] = {}

        self.register_action(_func=self.repr,    help="Print str(module).", args=[])
        self.register_action(_func=self.yaml,    help="Print evaluated config.", args=[])
        self.register_action(_func=self.hparams, help="Print hparams, like during training.", args=[])
        self.register_action(_func=self.dot,     help="Print graphviz graph of computation graph.", args=[
            ("-e", "--eval",   dict(action="store_true")),
            ("-f", "--filter", dict(action="store_true")),
        ])
        self.register_action(_func=self.jit,   help="Print a TorchScript graph of the model.", args=[])
        self.register_action(_func=self.trace, help="Dump a TorchScript trace of the model.", args=[
            ("output_file", dict(type=Path,
                help="Path to write the .pt file. Use \"-\" to instead open the trace in Netron.app")),
        ])
        self.register_action(_func=self.onnx, help="Dump an ONNX trace of the model.", args=[
            ("output_file", dict(type=Path,
                help="Path to write the .onnx file. Use \"-\" to instead open the onnx in Netron.app")),
        ])

        if self.datamodule_cls:
            names       = [i.__name__ for i in self.datamodule_cls]
            names_snake = [datamodule_name_to_snake_case(i) for i in names]
            assert len(names) == len(set(names)), \
                f"Datamodule names are not unique: {names!r}"
            assert len(names) == len(set(names_snake)), \
                f"Datamodule snake-names are not unique: {names_snake!r}"

            self.register_action(_func=self.test_dataloader,
                help="Benchmark the speed of the dataloader.",
                args=[
                    ("datamodule", dict(type=str, default=None, nargs='?', choices=names_snake,
                        help="Which dataloader to test. Defaults to the first one found in config.")),
                    ("--limit-cores", dict(type=int, default=None,
                        help="Limits the cpu affinity to N cores. Useful to simulate a SLURM environment.")),
                    ("--profile", dict(type=Path, default=None,
                        help="Profile using cProfile, marshaling the result to a .prof or .log file.")),
                    ("-n", "--n-rounds", dict(type=int, default=3,
                        help="Number of times to read the dataloader.")),
                ],
                conflict_handler  = "ignore" if len(self.datamodule_cls) > 1 else "error",
                add_argparse_args = [i.add_argparse_args for i in self.datamodule_cls],
            )

    # decorator
    def register_pre_training_callback(self, func: FitHook):
        self.pre_fit_handlers.append(func)
        return func

    # decorator
    def register_post_training_callback(self, func: FitHook):
        self.post_fit_handlers.append(func)
        return func

    # decorator
    def register_action(self, *,
            help : str,
            args : list[tuple] = [],  # each tuple: (*add_argument positionals, kwargs dict)
            _func : Optional[Callable[[Namespace, Munch, M], None]] = None,
            add_argparse_args : Union[list[Callable[[ArgumentParser], ArgumentParser]], Callable[[ArgumentParser], ArgumentParser], None] = None,
            **kw,
            ):
        def wrapper(action: Callable[[Namespace, Munch, M], None]):
            cli_name = action.__name__.lower().replace("_", "-")
            self._registered_actions[cli_name] = (
                action,
                args,
                kw | {"help": help},
                add_argparse_args,
            )
            return action
        if _func is not None:  # shortcut
            return wrapper(_func)
        else:
            return wrapper
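
    # Illustrative decorator usage (the action name is hypothetical): the function
    # name becomes the CLI subcommand (underscores become dashes), invoked as
    # `... module <config> export-weights out.pt`:
    #
    #   @cli.register_action(help="Dump the module state_dict.", args=[
    #       ("output_file", dict(type=Path, help="Where to write the .pt file.")),
    #   ])
    #   def export_weights(args, config, module):
    #       torch.save(module.state_dict(), args.output_file)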

    def make_parser(self,
            parser      : ArgumentParser    = None,
            subparsers  : _SubParsersAction = None,
            add_trainer : bool              = False,
            ) -> tuple[ArgumentParser, _SubParsersAction, _SubParsersAction]:
        if parser is None:
            parser = ArgumentParser()
        if subparsers is None:
            subparsers = parser.add_subparsers(dest="mode", required=True)

        parser.add_argument("-pm", "--post-mortem", action="store_true",
            help="Start a debugger if an uncaught exception is thrown.")

        # Template generation and exploration
        parser_template = subparsers.add_parser("template",
            help="Generate or evaluate a config template")
        if 1:  # fold me
            parser_mode_mutex = parser_template.add_mutually_exclusive_group()  # (required=True)
            parser_mode_mutex.add_argument("-e", "--evaluate", metavar="TEMPLATE", type=Path,
                help="Read jinja2 yaml config template file, then evaluate and print it.")
            parser_mode_mutex.add_argument("-p", "--parse", metavar="TEMPLATE", type=Path,
                help="Read jinja2 yaml config template file, then evaluate, parse and print it.")

            def pair(data: str) -> tuple[str, str]:
                key, sep, value = data.partition("=")
                if not sep:
                    if key in os.environ:
                        value = os.environ[key]
                    else:
                        raise ValueError(f"the variable {key!r} was not given any value, and none was found in the environment.")
                elif "$" in value:
                    value = string.Template(value).substitute(os.environ)
                return (key, value)
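            # Illustrative behavior (assuming HOME is set in the environment):
            #   pair("lr=0.1")        -> ("lr", "0.1")
            #   pair("HOME")          -> ("HOME", os.environ["HOME"])
            #   pair("dir=$HOME/out") -> ("dir", "<expanded $HOME>/out")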

            parser_template.add_argument("-O", dest="jinja2_variables", action="append", type=pair,
                help="Variable available as string in the jinja2. (a=b). b will be expanded as an"
                     " env var if prefixed with $, or set equal to the env var a if =b is omitted.")

            parser_template.add_argument("-s", "--strict", action="store_true",
                help="Enable {% do require_defined(\"var\",var) %}".replace("%", "%%"))
            parser_template.add_argument("-d", "--defined-only", action="store_true",
                help="Disallow any use of undefined variables")

        # Load a module
        parser_module = subparsers.add_parser("module", aliases=["model"],
            help="Load a config template, evaluate it and use the resulting module")
        if 1:  # fold me
            parser_module.add_argument("module_file", type=Path,
                help="Jinja2 yaml config template or pytorch-lightning .ckpt file.")
            parser_module.add_argument("-O", dest="jinja2_variables", action="append", type=pair,
                help="Variable available as string in the jinja2. (a=b). b will be expanded as an"
                     " env var if prefixed with $, or set equal to the env var a if =b is omitted.")
            parser_module.add_argument("--last", action="store_true",
                help="If multiple ckpt match, prefer the last one.")
            parser_module.add_argument("--best", action="store_true",
                help="If multiple ckpt match, prefer the best one.")

            parser_module.add_argument("--add-shape-prehook", action="store_true",
                help="Add a forward pre-hook which prints the tensor shapes of all inputs, but not the outputs.")
            parser_module.add_argument("--add-shape-hook", action="store_true",
                help="Add a forward hook which prints the tensor shapes of all inputs AND outputs.")
            parser_module.add_argument("--add-oob-hook", action="store_true",
                help="Add a forward hook checking for INF and NaN values in inputs or outputs.")
            parser_module.add_argument("--add-oob-hook-input", action="store_true",
                help="Add a forward hook checking for INF and NaN values in inputs.")
            parser_module.add_argument("--add-oob-hook-output", action="store_true",
                help="Add a forward hook checking for INF and NaN values in outputs.")

        module_actions_subparser = parser_module.add_subparsers(dest="action", required=True)

        # add pluggables
        for name, (action, args, kw, add_argparse_args) in self._registered_actions.items():
            action_parser = module_actions_subparser.add_parser(name, **kw)
            if add_argparse_args:
                for func in add_argparse_args if isinstance(add_argparse_args, list) else [add_argparse_args]:
                    action_parser = func(action_parser)
            for *a, kw in args:
                action_parser.add_argument(*a, **kw)

        # Module: train or test
        if self.datamodule_cls:
            parser_trainer = module_actions_subparser.add_parser("fit", aliases=["test"],
                help="Train/fit or evaluate the module with train/val or test data.")

            # pl.Trainer
            parser_trainer = pl.Trainer.add_argparse_args(parser_trainer)

            # datamodule
            parser_trainer.add_argument("datamodule", type=str, default=None, nargs='?',
                choices=[datamodule_name_to_snake_case(i) for i in self.datamodule_cls],
                help="Which datamodule to use. Defaults to the first one found in config.")
            if len(self.datamodule_cls) > 1:
                # check that none of the datamodules conflict with trainer or module
                for datamodule_cls in self.datamodule_cls:
                    datamodule_cls.add_argparse_args(copy.deepcopy(parser_trainer))  # will raise on conflict
            # Merge the datamodule options, the above sanity check makes it "okay"
            with ignore_action_container_conflicts(parser_trainer):
                for datamodule_cls in self.datamodule_cls:
                    parser_trainer = datamodule_cls.add_argparse_args(parser_trainer)

            # defaults and jinja template
            self._included_in_config_template.clear()
            remove_options_from_parser(parser_trainer, "--logger")
            parser_trainer.set_defaults(**self.trainer_defaults)
            self.add_to_jinja_template("trainer", pl.Trainer, defaults=self.trainer_defaults, exclude_list={
                # not yaml friendly, already covered anyway:
                "logger",
                "plugins",
                "callbacks",
                # deprecated or covered by callbacks:
                "stochastic_weight_avg",
                "enable_model_summary",
                "track_grad_norm",
                "log_gpu_memory",
            })
            for datamodule_cls in self.datamodule_cls:
                self.add_to_jinja_template(datamodule_cls.__name__, datamodule_cls,
                    comment=f"select with {datamodule_name_to_snake_case(datamodule_cls)!r}")  # , commented=False)
            self.add_to_jinja_template("logging", logging, save_dir="logdir", commented=False)

        return parser, subparsers, module_actions_subparser

    def add_to_jinja_template(self, name: str, func: callable, **kwargs):
        """
        Basically a call to `make_jinja_template`.
        Will ensure the keys are present in the output from `from_argparse_args`.
        """
        self._included_in_config_template[name] = (func, dict(commented=True) | kwargs)

    def make_jinja_template(self) -> str:
        return "\n".join([
            f'#!/usr/bin/env -S python {sys.argv[0]} module',
            r'{% do require_defined("select", select, 0, "$SLURM_ARRAY_TASK_ID") %}{# requires jinja2.ext.do #}',
            r"{% set counter = itertools.count(start=0, step=1) %}",
            r"",
            r"{% set hp_matrix = namespace() %}{# hyper parameter matrix #}",
            r"{% set hp_matrix.my_hparam = [0] %}{##}",
            r"",
            r"{% for hp in cartesian_hparams(hp_matrix) %}{##}",
            r"{#% for hp in ablation_hparams(hp_matrix, cartesian_keys=[]) %}{##}",
            r"",
            r"{% set index = next(counter) %}",
            r"{% if select is not defined and index > 0 %}---{% endif %}",
            r"{% if select is not defined or int(select) == index %}",
            r"",
            *[
                func.make_jinja_template(name=name, **kwargs)
                if hasattr(func, "make_jinja_template") else
                param.make_jinja_template(func, name=name, **kwargs)
                for name, (func, kwargs) in self._included_in_config_template.items()
            ],
            r"{% autoescape false %}",
            r'{% do require_defined("experiment_name", experiment_name, "test", strict=true) %}',
            f"experiment_name: { self.experiment_name_prefix }-{{{{ experiment_name }}}}",
            r'{#--#}-{{ hp.my_hparam }}',
            r'{#--#}-{{ gen_run_uid(4) }} # select with -Oselect={{ index }}',
            r"{% endautoescape %}",
            self.module_cls.make_jinja_template(),
            r"{% endif %}{# -Oselect #}",
            r"",
            r"{% endfor %}",
            r"",
            r"{% set index = next(counter) %}",
            r"# number of possible 'select': {{ index }}, from 0 to {{ index-1 }}",
            r"# local: for select in {0..{{ index-1 }}}; do python ... -Oselect=$select ... ; done",
            r"# local: for select in {0..{{ index-1 }}}; do python -O {{ argv[0] }} model marf.yaml.j2 -Oselect=$select -Oexperiment_name='{{ experiment_name }}' fit --accelerator gpu ; done",
            r"# slurm: sbatch --array=0-{{ index-1 }} runcommand.slurm python ... -Oselect=\$SLURM_ARRAY_TASK_ID ...",
            r"# slurm: sbatch --array=0-{{ index-1 }} runcommand.slurm python -O {{ argv[0] }} model this-file.yaml.j2 -Oselect=\$SLURM_ARRAY_TASK_ID -Oexperiment_name='{{ experiment_name }}' fit --accelerator gpu --devices -1 --strategy ddp",
        ])

    def run(self, args=None, args_hook: Optional[Callable[[ArgumentParser, _SubParsersAction, _SubParsersAction], None]] = None):
        parser, mode_subparser, action_subparser = self.make_parser()
        if args_hook is not None:
            args_hook(parser, mode_subparser, action_subparser)
        args = parser.parse_args(args)  # may exit
        if args.post_mortem and not os.isatty(0):
            warnings.warn("post-mortem debugging is enabled without any TTY attached. Will be ignored.")
        if args.post_mortem and os.isatty(0):
            try:
                self.handle_args(args)
            except Exception:
                # print exception
                sys.excepthook(*sys.exc_info())
                # debug
                *debug_module, debug_func = os.environ.get("PYTHONBREAKPOINT", "pdb.set_trace").split(".")
                importlib.import_module(".".join(debug_module)).post_mortem()
                exit(1)
        else:
            self.handle_args(args)

    def handle_args(self, args: Namespace):
        """
        May call exit()
        """
        if args.mode == "template":

            if args.evaluate or args.parse:
                template_file = args.evaluate or args.parse
                env = param.make_jinja_env(globals=param.make_jinja_globals(enable_require_defined=args.strict), allow_undef=not args.defined_only)
                if str(template_file) == "-":
                    template = env.from_string(sys.stdin.read(), globals=dict(args.jinja2_variables or []))
                else:
                    template = env.get_template(str(template_file.absolute()), globals=dict(args.jinja2_variables or []))
                config_yaml = param.squash_newlines(template.render())  # .lstrip("\n").rstrip()
                if args.evaluate:
                    _print_with_syntax_highlighting("yaml+jinja", config_yaml)
                else:
                    config = yaml.load(config_yaml, UniqueKeyYAMLLoader)
                    CONSOLE.print(config)

            else:
                _print_with_syntax_highlighting("yaml+jinja", self.make_jinja_template())

        elif args.mode in ("module", "model"):

            module: nn.Module

            if not args.module_file.is_file():
                matches = [*Path("logdir/tensorboard").rglob(f"*-{args.module_file}/checkpoints/*.ckpt")]
                if len(matches) == 1:
                    args.module_file, = matches
                elif len(matches) > 1:
                    if (args.last or args.best) and len(set(match.parent.parent.name for match in matches)) == 1:
                        if args.last:
                            args.module_file, = (match for match in matches if match.name == "last.ckpt")
                        elif args.best:
                            args.module_file, = (match for match in matches if match.name.startswith("epoch="))
                        else:
                            assert False
                    else:
                        raise ValueError("uid matches multiple paths:\n" + "\n".join(map(str, matches)))
                else:
                    raise ValueError("path does not exist, and is not a uid")

            # load module from cli args
            if args.module_file.suffix == ".ckpt":  # from checkpoint
                # load from checkpoint
                rich.print(f"Loading module from {str(args.module_file)!r}...", file=sys.stderr)
                module = self.module_cls.load_from_checkpoint(args.module_file)

                if (args.module_file.parent.parent / "hparams.yaml").is_file():
                    with (args.module_file.parent.parent / "hparams.yaml").open() as f:
                        config_yaml = yaml.load(f.read(), UniqueKeyYAMLLoader)["_pickled_cli_args"]["_raw_yaml"]
                else:
                    with (args.module_file.parent.parent / "config.yaml").open() as f:
                        config_yaml = f.read()

                config = munchify(yaml.load(config_yaml, UniqueKeyYAMLLoader) | {"_raw_yaml": config_yaml})

            else:  # from yaml

                # read, evaluate and parse config
                if args.module_file.suffix == ".j2" or str(args.module_file) == "-":
                    env = param.make_jinja_env()
                    if str(args.module_file) == "-":
                        template = env.from_string(sys.stdin.read(), globals=dict(args.jinja2_variables or []))
                    else:  # jinja+yaml file
                        template = env.get_template(str(args.module_file.absolute()), globals=dict(args.jinja2_variables or []))
                    config_yaml = param.squash_newlines(template.render()).lstrip("\n").rstrip()
                else:  # yaml file (the git diffs in _pickled_cli_args may trigger jinja's escape sequences)
                    with args.module_file.open() as f:
                        config_yaml = f.read().lstrip("\n").rstrip()

                config = yaml.load(config_yaml, UniqueKeyYAMLLoader)

                if "_pickled_cli_args" in config:  # hparams.yaml in tensorboard logdir
                    config_yaml = config["_pickled_cli_args"]["_raw_yaml"]
                    config = yaml.load(config_yaml, UniqueKeyYAMLLoader)

                from_checkpoint: Optional[Path] = None
                # NOTE: materialize the glob, the generator itself would always be truthy
                checkpoints_fnames = list((args.module_file.parent / "checkpoints").glob("*.ckpt"))
                if checkpoints_fnames:
                    if len(checkpoints_fnames) == 1:
                        from_checkpoint = checkpoints_fnames[0]
                    elif args.last:
                        from_checkpoint, = (i for i in checkpoints_fnames if i.name == "last.ckpt")
                    elif args.best:
                        from_checkpoint, = (i for i in checkpoints_fnames if i.name.startswith("epoch="))
                    else:
                        rich.print(f"[yellow]WARNING:[/] {str(args.module_file.parent / 'checkpoints')!r} contains more than one checkpoint, unable to automatically load one.", file=sys.stderr)

                config = munchify(config | {"_raw_yaml": config_yaml})

                # Ensure the experiment name ends in a date and uid, allowing for reruns and organization
                assert config.experiment_name
                assert re.match(r'^.*-[0-9]{4}-[0-9]{2}-[0-9]{2}-[0-9]{4}-[a-z]{4}$', config.experiment_name), \
                    config.experiment_name

                # init the module
                if from_checkpoint:
                    rich.print(f"Loading module from {str(from_checkpoint)!r}...", file=sys.stderr)
                    module = self.module_cls.load_from_checkpoint(from_checkpoint)
                else:
                    module = self.module_cls(**{k: v for k, v in config[self.module_cls.__name__].items() if k != "_extra"})

            # optional debugging forward hooks

            if args.add_shape_hook or args.add_shape_prehook:
                def shape_forward_hook(is_prehook: bool, name: str, module: nn.Module, input, output=None):
                    def tensor_to_shape(val):
                        if isinstance(val, torch.Tensor):
                            return tuple(val.shape)
                        elif isinstance(val, (str, float, int)) or val is None:
                            return 1
                        else:
                            assert 0, (val, name)
                    with torch.no_grad():
                        rich.print(
                            f"{name}.forward({helpers.map_tree(tensor_to_shape, input)})"
                            if is_prehook else
                            f"{name}.forward({helpers.map_tree(tensor_to_shape, input)})"
                            f" -> {helpers.map_tree(tensor_to_shape, output)}",
                            file=sys.stderr)

                for submodule_name, submodule in module.named_modules():
                    if submodule_name:
                        submodule_name = f"{module.__class__.__qualname__}.{submodule_name}"
                    else:
                        submodule_name = f"{module.__class__.__qualname__}"
                    if args.add_shape_prehook:
                        submodule.register_forward_pre_hook(partial(shape_forward_hook, True, submodule_name))
                    if args.add_shape_hook:
                        submodule.register_forward_hook(partial(shape_forward_hook, False, submodule_name))

            if args.add_oob_hook or args.add_oob_hook_input or args.add_oob_hook_output:
                def oob_forward_hook(name: str, module: nn.Module, input, output):
                    def raise_if_oob(key, val):
                        if isinstance(val, collections.abc.Mapping):
                            for k, subval in val.items():
                                raise_if_oob(f"{key}[{k!r}]", subval)
                        elif isinstance(val, (tuple, list)):
                            for i, subval in enumerate(val):
                                raise_if_oob(f"{key}[{i}]", subval)
                        elif isinstance(val, torch.Tensor):
                            assert not torch.isinf(val).any(), \
                                f"INFs found in {key}"
                            assert not val.isnan().any(), \
                                f"NaNs found in {key}"
                        elif isinstance(val, (str, float, int)):
                            pass
                        elif val is None:
                            warnings.warn(f"None found in {key}")
                        else:
                            assert False, val
                    with torch.no_grad():
                        if args.add_oob_hook or args.add_oob_hook_input:
                            raise_if_oob(f"{name}.forward input", input)
                        if args.add_oob_hook or args.add_oob_hook_output:
                            raise_if_oob(f"{name}.forward output", output)

                for submodule_name, submodule in module.named_modules():
                    submodule.register_forward_hook(partial(oob_forward_hook,
                        f"{module.__class__.__qualname__}.{submodule_name}"
                        if submodule_name else
                        f"{module.__class__.__qualname__}"
                    ))

            # Ensure all the top-level config keys are there
            for key in self._included_in_config_template.keys():
                if key in (i.__name__ for i in self.datamodule_cls):
                    continue
                if key not in config or config[key] is None:
                    config[key] = {}

            # Run registered action
            if args.action in self._registered_actions:
                action, *_ = self._registered_actions[args.action]
                action(args, config, module)
            elif args.action in ("fit", "test") and self.datamodule_cls is not None:
                self.fit(args, config, module)
            else:
                raise ValueError(f"{args.mode=}, {args.action=}")

        else:
            raise ValueError(f"{args.mode=}")

    def get_datamodule_cls_from_config(self, args: Namespace, config: Munch) -> type[DM]:
        assert self.datamodule_cls
        cli = getattr(args, "datamodule", None)
        datamodule_cls: type[pl.LightningDataModule]
        if cli is not None:
            datamodule_cls, = (i for i in self.datamodule_cls if datamodule_name_to_snake_case(i) == cli)
        else:
            datamodules = {
                cls.__name__: cls
                for cls in self.datamodule_cls
            }
            for key in config.keys():
                if key in datamodules:
                    datamodule_cls = datamodules[key]
                    break
            else:
                datamodule_cls = self.datamodule_cls[0]
                warnings.warn(f"None of the following datamodules were found in config: {set(datamodules.keys())!r}. {datamodule_cls.__name__!r} was chosen as the default.")

        return datamodule_cls

    def init_datamodule_cls_from_config(self, args: Namespace, config: Munch) -> DM:
        datamodule_cls = self.get_datamodule_cls_from_config(args, config)
        return datamodule_cls.from_argparse_args(args, **(config.get(datamodule_cls.__name__) or {}))


    # Module actions

    def repr(self, args: Namespace, config: Munch, module: M):
        rich.print(module)

    def yaml(self, args: Namespace, config: Munch, module: M):
        _print_with_syntax_highlighting("yaml+jinja", config["_raw_yaml"])

    def dot(self, args: Namespace, config: Munch, module: M):
        module.train(not args.eval)
        assert not args.filter, "not implemented! pipe it through examples/scripts/filter_dot.py in the meanwhile"

        example_input_array = module.example_input_array
        assert example_input_array is not None, f"{module.__class__.__qualname__}.example_input_array=None"
        assert isinstance(example_input_array, (tuple, dict, torch.Tensor)), type(example_input_array)

        def set_requires_grad(val):
            if isinstance(val, torch.Tensor):
                val.requires_grad = True
            return val

        with torch.enable_grad():
            outputs = module(*helpers.map_tree(set_requires_grad, example_input_array))

        dot = torchviz.make_dot(outputs, params=dict(module.named_parameters()), show_attrs=False, show_saved=False)
        _print_with_syntax_highlighting("dot", str(dot))

    def jit(self, args: Namespace, config: Munch, module: M):
        example_input_array = module.example_input_array
        assert example_input_array is not None, f"{module.__class__.__qualname__}.example_input_array=None"
        assert isinstance(example_input_array, (tuple, dict, torch.Tensor)), type(example_input_array)
        trace = torch.jit.trace_module(module, {"forward": example_input_array})
        _print_with_syntax_highlighting("python", str(trace.inlined_graph))

    def trace(self, args: Namespace, config: Munch, module: M):
        if isinstance(module, pl.LightningModule):
            trace = module.to_torchscript(method="trace")
        else:
            example_input_array = module.example_input_array
            assert example_input_array is not None, f"{module.__class__.__qualname__}.example_input_array is None"
            assert isinstance(module, nn.Module)
            trace = torch.jit.trace_module(module, {"forward": example_input_array})

        use_netron = str(args.output_file) == "-"
        trace_f = io.BytesIO() if use_netron else args.output_file

        torch.jit.save(trace, trace_f)

        if use_netron:
            open_in_netron(f"{self.module_cls.__name__}.pt", trace_f.getvalue())

    def onnx(self, args: Namespace, config: Munch, module: M):
        example_input_array = module.example_input_array
        assert example_input_array is not None, f"{module.__class__.__qualname__}.example_input_array=None"
        assert isinstance(example_input_array, (tuple, dict, torch.Tensor)), type(example_input_array)

        use_netron = str(args.output_file) == "-"
        onnx_f = io.BytesIO() if use_netron else args.output_file

        torch.onnx.export(module,
            tuple(example_input_array),
            onnx_f,
            export_params       = True,
            opset_version       = 17,
            do_constant_folding = True,
            input_names         = ["input"],
            output_names        = ["output"],
            dynamic_axes        = {
                "input"  : {0: "batch_size"},
                "output" : {0: "batch_size"},
            },
        )

        if use_netron:
            open_in_netron(f"{self.module_cls.__name__}.onnx", onnx_f.getvalue())

    def hparams(self, args: Namespace, config: Munch, module: M):
        assert isinstance(module, self.module_cls)
        print(f"{self.module_cls.__qualname__} hparams:")
        print_column_dict(map_type_to_repr(module.hparams, nn.Module, lambda t: f"{t.__class__.__qualname__}"), 3)


    def fit(self, args: Namespace, config: Munch, module: M):
        is_rank_zero = pl.utilities.rank_zero_only.rank == 0

        metric_prefix = f"{module.__class__.__name__}.validation_step/"

        pl_callbacks = [
            pl.callbacks.LearningRateMonitor(log_momentum=True),
            pl.callbacks.EarlyStopping(monitor=metric_prefix + getattr(module, "metric_early_stop", "loss"), patience=200, check_on_train_epoch_end=False, verbose=True),
            pl.callbacks.ModelCheckpoint(monitor=metric_prefix + getattr(module, "metric_best_model", "loss"), mode="min", save_top_k=1, save_last=True),
            logging.ModelOutputMonitor(),
            logging.EpochTimeMonitor(),
            (pl.callbacks.RichModelSummary if os.isatty(1) else pl.callbacks.ModelSummary)(max_depth=30),
            logging.PsutilMonitor(),
        ]
        if os.isatty(1):
            pl_callbacks.append(pl.callbacks.RichProgressBar())

        trainer: pl.Trainer
        logger = logging.make_logger(config.experiment_name, config.trainer.get("default_root_dir", args.default_root_dir or self.workdir), **config.logging)
        trainer = pl.Trainer.from_argparse_args(args, logger=logger, callbacks=pl_callbacks, **config.trainer)

        datamodule = self.init_datamodule_cls_from_config(args, config)

        for f in self.pre_fit_handlers:
            print(f"pre-train hook {f.__name__!r}...")
            f(args, config, module, trainer, datamodule, logger)

        # print and log hparams/config
        if 1:  # fold me
            if is_rank_zero:
                CONSOLE.print(f"Experiment name: {config.experiment_name!r}", soft_wrap=False, crop=False, no_wrap=False, overflow="ignore")

            # parser.args and sys.argv
            pickled_cli_args = dict(
                sys_argv    = sys.argv,
                parser_args = args.__dict__,
                config      = config.copy(),
                _raw_yaml   = config["_raw_yaml"],
            )
            del pickled_cli_args["config"]["_raw_yaml"]
            for k, v in pickled_cli_args["parser_args"].items():
                if isinstance(v, Path):
                    pickled_cli_args["parser_args"][k] = str(v)

            # trainer
            params_trainer = inspect.signature(pl.Trainer.__init__).parameters
            trainer_hparams = vars(pl.Trainer.parse_argparser(args))
            trainer_hparams = {name: trainer_hparams[name] for name in params_trainer if name in trainer_hparams}
            if is_rank_zero:
                print("pl.Trainer hparams:")
                print_column_dict(trainer_hparams, 3)
            pickled_cli_args.update(trainer_hparams=trainer_hparams)

            # module
            assert isinstance(module, self.module_cls)
            if is_rank_zero:
                print(f"{self.module_cls.__qualname__} hparams:")
                print_column_dict(map_type_to_repr(module.hparams, nn.Module, lambda t: f"{t.__class__.__qualname__}"), 3)
            pickled_cli_args.update(module_hparams={
                k: v
                for k, v in module.hparams.items()
                if k != "_raw_yaml"
            })

            # module extra state, like autodecoder uids
            for submodule_name, submodule in module.named_modules():
                if not submodule_name:
                    submodule_name = module.__class__.__qualname__
                else:
                    submodule_name = module.__class__.__qualname__ + "." + submodule_name
                try:
                    state = submodule.get_extra_state()
                except RuntimeError:
                    continue
                if "extra_state" not in pickled_cli_args:
                    pickled_cli_args["extra_state"] = {}
                pickled_cli_args["extra_state"][submodule_name] = state

            # datamodule
            if self.datamodule_cls:
                assert datamodule is not None and any(isinstance(datamodule, i) for i in self.datamodule_cls), datamodule
                for datamodule_cls in self.datamodule_cls:
                    params_d = inspect.signature(datamodule_cls.__init__).parameters
                    assert {"self"} == set(params_trainer).intersection(params_d), \
                        f"trainer and datamodule have overlapping params: {set(params_trainer).intersection(params_d) - {'self'}}"

                if is_rank_zero:
                    print(f"{datamodule.__class__.__qualname__} hparams:")
                    print_column_dict(datamodule.hparams)
                pickled_cli_args.update(datamodule_hparams=dict(datamodule.hparams))

            # logger
            if logger is not None:
                print(f"{logger.__class__.__qualname__} hparams:")
                print_column_dict(config.logging)
                pickled_cli_args.update(logger_hparams={"_class": logger.__class__.__name__} | config.logging)

            # host info
            def cmd(cmd: Union[str, list[str]]) -> str:
                if isinstance(cmd, str):
                    cmd = shlex.split(cmd)
                if shutil.which(cmd[0]):
                    try:
                        return subprocess.run(cmd,
                            capture_output = True,
                            check          = True,
                            text           = True,
                        ).stdout.strip()
                    except subprocess.CalledProcessError as e:
                        warnings.warn(f"{e.__class__.__name__}: {e}")
                        return f"{e.__class__.__name__}: {e}\n{e.output = }\n{e.stderr = }"
                else:
                    warnings.warn(f"command {cmd[0]!r} not found")
                    return f"*command {cmd[0]!r} not found*"

            pickled_cli_args.update(host = dict(
                platform = textwrap.dedent(f"""
                    {platform.architecture() = }
                    {platform.java_ver() = }
                    {platform.libc_ver() = }
                    {platform.mac_ver() = }
                    {platform.machine() = }
                    {platform.node() = }
                    {platform.platform() = }
                    {platform.processor() = }
                    {platform.python_branch() = }
                    {platform.python_build() = }
                    {platform.python_compiler() = }
                    {platform.python_implementation() = }
                    {platform.python_revision() = }
                    {platform.python_version() = }
                    {platform.release() = }
                    {platform.system() = }
                    {platform.uname() = }
                    {platform.version() = }
                    """.rstrip()).lstrip(),
                cuda = dict(
                    gpus = [
                        torch.cuda.get_device_name(i)
                        for i in range(torch.cuda.device_count())
                    ],
                    available = torch.cuda.is_available(),
                    version   = torch.version.cuda,
                ),
                hostname    = cmd("hostname --fqdn"),
                cwd         = os.getcwd(),
                date        = datetime.now().astimezone().isoformat(),
                date_utc    = datetime.utcnow().isoformat(),
                ifconfig    = cmd("ifconfig"),
                lspci       = cmd("lspci"),
                lsusb       = cmd("lsusb"),
                lsblk       = cmd("lsblk"),
                mount       = cmd("mount"),
                environ     = os.environ.copy(),
                vcs         = f"commit {cmd('git rev-parse HEAD')}\n{cmd('git status')}\n{cmd('git diff --stat --patch HEAD')}",
                venv_pip    = cmd("pip list --format=freeze"),
                venv_conda  = cmd("conda list"),
                venv_poetry = cmd("poetry show -t"),
                gpus        = [i.split(", ") for i in cmd("nvidia-smi --query-gpu=index,name,memory.total,driver_version,uuid --format=csv").splitlines()],
            ))

        if logger is not None:
            logging.log_config(logger, _pickled_cli_args=pickled_cli_args)

        warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning)

        if __debug__ and is_rank_zero:
            warnings.warn("You're running python with assertions active. Enable optimizations with `python -O` for improved performance.")

        # train

        t_begin = datetime.now()
        if args.action == "fit":
            trainer.fit(module, datamodule)
        elif args.action == "test":
            trainer.test(module, datamodule)
        else:
            raise ValueError(f"{args.mode=}, {args.action=}")

        if not is_rank_zero:
            return

        t_end = datetime.now()
        print(f"Training time: {t_end - t_begin}")

        for f in self.post_fit_handlers:
            print(f"post-train hook {f.__name__!r}...")
            try:
                f(args, config, module, trainer, datamodule, logger)
            except Exception:
                traceback.print_exc()

        rich.print(f"Experiment name: {config.experiment_name!r}")
        rich.print(f"Best model path: {str(helpers.make_relative(trainer.checkpoint_callback.best_model_path))!r}")
        rich.print(f"Last model path: {str(helpers.make_relative(trainer.checkpoint_callback.last_model_path))!r}")

    def test_dataloader(self, args: Namespace, config: Munch, module: M):
        # limit CPU affinity
        if args.limit_cores is not None:
            # https://stackoverflow.com/a/40856471
            p = psutil.Process()
            assert len(p.cpu_affinity()) >= args.limit_cores
            cpus = list(range(args.limit_cores))
            p.cpu_affinity(cpus)
            print("Process limited to CPUs", cpus)

        datamodule = self.init_datamodule_cls_from_config(args, config)

        # setup
        rich.print(f"Setup {datamodule.__class__.__qualname__}...")
        datamodule.prepare_data()
        datamodule.setup("fit")
        try:
            train = datamodule.train_dataloader()
        except (MisconfigurationException, NotImplementedError):
            train = None
        try:
            val = datamodule.val_dataloader()
        except (MisconfigurationException, NotImplementedError):
            val = None
        try:
            test = datamodule.test_dataloader()
        except (MisconfigurationException, NotImplementedError):
            test = None

        # inspect
        rich.print("batch[0] = ", end="")
        rich.pretty.pprint(
            map_type_to_repr(
                next(iter(train)),
                torch.Tensor,
                lambda x: f"Tensor(..., shape={x.shape}, dtype={x.dtype}, device={x.device})",
            ),
            indent_guides = False,
        )

        if args.profile is not None:
            import cProfile
            profile = cProfile.Profile()
            profile.enable()

        # measure
        n_train, td_train = 0, 0
        n_val,   td_val   = 0, 0
        n_test,  td_test  = 0, 0
        try:
            for i in range(args.n_rounds):
                print(f"Round {i+1} of {args.n_rounds}")
                if train is not None:
                    epoch = time.perf_counter_ns()
                    n_train += sum(1 for _ in tqdm(train, desc=f"train {i+1}/{args.n_rounds}"))
                    td_train += time.perf_counter_ns() - epoch
                if val is not None:
                    epoch = time.perf_counter_ns()
                    n_val += sum(1 for _ in tqdm(val, desc=f"val {i+1}/{args.n_rounds}"))
                    td_val += time.perf_counter_ns() - epoch
                if test is not None:
                    epoch = time.perf_counter_ns()
                    n_test += sum(1 for _ in tqdm(test, desc=f"test {i+1}/{args.n_rounds}"))
                    td_test += time.perf_counter_ns() - epoch
        except KeyboardInterrupt:
            rich.print("Received a `KeyboardInterrupt`...")

        if args.profile is not None:
            profile.disable()
            if str(args.profile) != "-":  # NOTE: compare as str, args.profile is a Path
                profile.dump_stats(args.profile)
            profile.print_stats("tottime")

        # summary
        for label, data, n, td in [
            ("train", train, n_train, td_train),
            ("val",   val,   n_val,   td_val),
            ("test",  test,  n_test,  td_test),
        ]:
            if not n: continue
            if data is not None:
                print(f"{label}:",
                    f" - per epoch: {td / args.n_rounds * 1e-9 :11.6f} s",
                    f" - per batch: {td / n * 1e-9 :11.6f} s",
                    f" - batches/s: {n / (td * 1e-9):11.6f}",
                    sep="\n")

        datamodule.teardown("fit")
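

# Illustrative wiring (MyModule and MyDataModule are hypothetical):
#
#   cli = CliInterface(module_cls=MyModule, workdir=Path("logdir"),
#                      datamodule_cls=[MyDataModule])
#   if __name__ == "__main__":
#       cli.run()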


# helpers:

def open_in_netron(filename: str, data: bytes, *, timeout: float = 10):
    # filename is only used to determine the filetype
    url = serve_once_in_background(
        data,
        mime_type = "application/octet-stream",
        timeout   = timeout,
        port      = gen_random_port(),
    )
    url = f"https://netron.app/?url={urllib.parse.quote(url)}{filename}"
    print("Open in Netron:", url)
    webbrowser.get("firefox").open_new_tab(url)
    if timeout:
        time.sleep(timeout)

def remove_options_from_parser(parser: ArgumentParser, *options: str):
    options = set(options)
    # https://stackoverflow.com/questions/32807319/disable-remove-argument-in-argparse/36863647#36863647
    for action in parser._actions:
        if action.option_strings:
            option = action.option_strings[0]
            if option in options:
                parser._handle_conflict_resolve(None, [(option, action)])

def map_type_to_repr(batch, type_match: type, repr_func: callable):
    def mapper(value):
        if isinstance(value, type_match):
            return helpers.CustomRepr(repr_func(value))
        else:
            return value
    return helpers.map_tree(mapper, batch)
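
# Illustrative call (shapes made up): replaces tensors with a short placeholder
# repr so pretty-printing a batch doesn't dump the full data:
#
#   map_type_to_repr({"x": torch.zeros(2, 3)}, torch.Tensor,
#                    lambda t: f"Tensor{tuple(t.shape)}")  # -> {'x': Tensor(2, 3)}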

def datamodule_name_to_snake_case(datamodule: Union[str, type[DM]]) -> str:
    if not isinstance(datamodule, str):
        datamodule = datamodule.__name__
    datamodule = datamodule.replace("DataModule", "Datamodule")
    if datamodule != "Datamodule":
        datamodule = datamodule.removesuffix("Datamodule")
    return camel_to_snake_case(datamodule, sep="-", join_abbreviations=True)
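
# Illustrative results (assuming camel_to_snake_case behaves as its name suggests):
#   datamodule_name_to_snake_case("ShapeNetDataModule") -> "shape-net"
#   datamodule_name_to_snake_case("DataModule")         -> "datamodule" (suffix kept)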