Files
marf/ifield/param.py
2025-01-09 15:43:11 +01:00

232 lines
8.6 KiB
Python

from .utils.helpers import compose, elementwise_max
from datetime import datetime
from torch import nn
from typing import Any, Literal, Iterable, Union, Callable, Optional
import inspect
import jinja2
import json
import os
import random
import re
import shlex
import string
import sys
import time
import typing
import warnings
import yaml
# Sentinel marking "this parameter has no default value".
# It is yielded by `list_func_params` and checked with `is` / `is not` in
# `make_jinja_template`, so it is compared by identity, not by string equality.
_UNDEFINED = " I AM UNDEFINED "
def _yaml_encode_value(val) -> str:
    """Encode `val` as a single-line string suitable for a YAML value.

    Tuples and sets are converted to lists first. Lists and dicts are
    rendered with `json.dumps` (flow style, valid YAML); anything else goes
    through `yaml.dump`, with the document-end marker and trailing newlines
    stripped.
    """
    if isinstance(val, (tuple, set)):
        val = list(val)

    if isinstance(val, (list, dict)):
        return json.dumps(val)

    dumped = yaml.dump(val)
    # scalars are emitted by yaml.dump as "<value>\n...\n"; drop that tail
    dumped = dumped.removesuffix("\n...\n")
    return dumped.rstrip("\n")
def _raise(val: Union[Exception, str]):
    """Raise `val`; a plain string is wrapped in `jinja2.TemplateError` first.

    Exists so that raising can be used as an expression (e.g. from a lambda
    or from inside a template).
    """
    error = jinja2.TemplateError(val) if isinstance(val, str) else val
    raise error
def _merge_hparam_sources(_map, kw: dict) -> dict:
    """Overlay the keyword hyperparameters `kw` on top of the optional base `_map`.

    `_map` may be a `jinja2.utils.Namespace`, a plain `dict`, or anything else
    (e.g. `None`), in which case `kw` is returned unchanged.
    """
    if isinstance(_map, jinja2.utils.Namespace):
        # Namespace keeps its data in a name-mangled private attribute
        return _map._Namespace__attrs | kw
    if isinstance(_map, dict):
        return _map | kw
    return kw


def make_jinja_globals(*, enable_require_defined: bool) -> dict:
    """Build the dictionary of global names made available to Jinja templates.

    The result contains every name bound in this function (the helper
    functions below, the locally imported modules and `enable_require_defined`
    itself), all Python builtins, `argv` (`sys.argv`) and `raise` (`_raise`).

    Args:
        enable_require_defined: when true, `require_defined` raises on a
            missing/incorrect variable instead of only emitting a warning.
    """
    # NOTE: everything bound at this level is exported through `locals()` in the
    # return statement, including these imports. Do not add scratch locals here.
    import builtins
    import functools
    import itertools
    import operator
    import json

    def require_defined(name, value, *defaults, failed: bool = False, strict: bool=False, exchaustive=False):
        """Check that the template variable `name` (whose value is `value`) is usable.

        `defaults` are example/allowed values shown to the user as `-O` options.
        If `value` is undefined, or the caller flags it with `failed=True`, a
        `ValueError` is raised when `enable_require_defined` or `strict` is set;
        otherwise a warning is emitted. With `exchaustive=True`, a defined
        value that is not one of `defaults` also raises `ValueError`.
        """
        if not defaults:
            raise ValueError("`require_defined` requires at least one valid value provided")

        undefined = jinja2.is_undefined(value)
        if undefined:
            # sanity check: the caller must pass the variable under its own name
            assert value._undefined_name == name, \
                f"Name mismatch: {value._undefined_name=}, {name=}"

        if failed or undefined:
            message = (
                f"Required variable {name!r} "
                f"is {'incorrect' if failed else 'undefined'}! "
                f"Try providing:\n" + "\n".join(
                    f"-O{shlex.quote(name)}={shlex.quote(str(default))}"
                    for default in defaults
                )
            )
            if enable_require_defined or strict:
                raise ValueError(message)
            warnings.warn(message)

        if exchaustive and not undefined and value not in defaults:
            raise ValueError(
                f"Variable {name!r} not in list of allowed values: {defaults!r}"
            )

    def gen_run_uid(n: int, _choice = random.Random(time.time_ns()).choice):
        """
        generates a UID for the experiment run, nice for regexes, grepping and timekeeping.

        The result is a minute-resolution timestamp followed by `n` random
        lowercase letters, e.g. `2025-01-09-1543-abcd`.
        """
        # we have _choice, since most likely, pl.seed_everything has been run by this point
        # we store it as a default parameter to reuse it, on the off-chance of two calls to this function being run within the same ns
        code = ''.join(_choice(string.ascii_lowercase) for _ in range(n))
        return f"{datetime.now():%Y-%m-%d-%H%M}-{code}"

    def cartesian_hparams(_map=None, **kw: dict[str, list]) -> Iterable[jinja2.utils.Namespace]:
        """Yield one Namespace per element of the cartesian product of the given value lists.

        Use this to bypass the common error 'SyntaxError: too many statically nested blocks'.
        `_map` (a Namespace or dict of value lists) is merged with `kw`; `kw` wins
        on conflicts. With no keys at all a single empty Namespace is yielded.
        """
        kw = _merge_hparam_sources(_map, kw)
        keys = list(kw)
        for combination in itertools.product(*kw.values()):
            yield jinja2.utils.Namespace(zip(keys, combination))

    def ablation_hparams(_map=None, *, caartesian_keys: Optional[list[str]] = None, **kw: dict[str, list]) -> Iterable[jinja2.utils.Namespace]:
        """Yield a baseline Namespace plus one-at-a-time ablation variations.

        Use this to bypass the common error 'SyntaxError: too many statically nested blocks'.
        Keys listed in `caartesian_keys` are expanded as a full cartesian
        product. For every such combination the baseline (first value of each
        remaining key) is yielded, followed by one Namespace per alternative
        value of each remaining key, changing a single key at a time.
        """
        kw = _merge_hparam_sources(_map, kw)
        keys = list(kw.keys())
        # keep the ordering of `kw`, and ignore requested keys that are absent
        caartesian_keys = [k for k in keys if k in caartesian_keys] if caartesian_keys else []
        ablation_keys = [k for k in keys if k not in caartesian_keys]
        caartesian_vals = [kw[k] for k in caartesian_keys]
        ablation_vals = [kw[k] for k in ablation_keys]

        for base_vals in itertools.product(*caartesian_vals):
            base = [
                *zip(caartesian_keys, base_vals),
                *zip(ablation_keys, (vals[0] for vals in ablation_vals)),
            ]
            yield jinja2.utils.Namespace(base)  # baseline
            for ablation_key, ablation_val in zip(ablation_keys, ablation_vals):
                for val in ablation_val[1:]:
                    yield jinja2.utils.Namespace(base, **{ablation_key: val})  # ablation variation

    return {
        **locals(),
        **vars(builtins),
        "argv": sys.argv,
        "raise": _raise,
    }
def make_jinja_env(globals = make_jinja_globals(enable_require_defined=True), allow_undef=False) -> jinja2.Environment:
    """Create the Jinja environment used to render templates.

    Templates are looked up relative to the current working directory and
    then relative to `/`, following symlinks. Undefined variables are errors
    (`StrictUndefined`) unless `allow_undef` is set. `globals` is merged into
    the environment globals; its default is built once, when this function is
    defined.

    Registered filters: `defined` (raises for undefined values, otherwise
    passes the value through), `repr`, `to_json`, and `bool`/`int`/`float`/`str`
    which cast the value and emit it JSON-encoded.
    """
    def _require_defined(x):
        if jinja2.is_undefined(x):
            return _raise(f"{x._undefined_name!r} is not defined!")
        return x

    def _json_cast(cast):
        # cast the value, then JSON-encode the result
        return lambda x: json.dumps(cast(x))

    undefined_cls = jinja2.Undefined if allow_undef else jinja2.StrictUndefined
    search_path = [os.getcwd(), "/"]

    env = jinja2.Environment(
        loader=jinja2.FileSystemLoader(search_path, followlinks=True),
        autoescape=False,
        trim_blocks=True,
        lstrip_blocks=True,
        undefined=undefined_cls,
        extensions=[
            "jinja2.ext.do",  # statements with side-effects
            "jinja2.ext.loopcontrols",  # break and continue
        ],
    )
    env.globals.update(globals)

    filters = {
        "defined": _require_defined,
        "repr": repr,
        "to_json": json.dumps,
    }
    for filter_name, cast in (("bool", bool), ("int", int), ("float", float), ("str", str)):
        filters[filter_name] = _json_cast(cast)
    env.filters.update(filters)
    return env
def list_func_params(func: callable, exclude_list: set[str], defaults: dict={}) -> Iterable[tuple[str, Any, str]]:
    """Yield `(name, default, type description)` for each documented parameter of `func`.

    Skipped parameters: a leading `self`/`cls`, names in `exclude_list`, names
    starting with an underscore, and `*args` / `**kwargs`.

    `defaults` overrides the default found in the signature. When there is no
    default at all, `_UNDEFINED` is yielded in its place. The type description
    is `in {a, b, ...}` for `Literal` annotations and the type's repr otherwise.

    Raises:
        AssertionError: if a listed parameter has no type annotation.
    """
    parameters = inspect.signature(func).parameters
    for index, (param_name, param) in enumerate(parameters.items()):
        if index == 0 and param_name in ("self", "cls"):
            continue
        if param_name in exclude_list or param_name.startswith("_"):
            continue
        if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue

        default = defaults.get(param_name, param.default)
        annotation = param.annotation

        literal_choices = None
        if typing.get_origin(annotation) is Literal:
            encoded = map(_yaml_encode_value, typing.get_args(annotation))
            literal_choices = "{" + ", ".join(encoded) + "}"

        assert annotation is not param.empty, f"param {param_name!r} has no type annotation"

        if literal_choices:
            type_description = f"in {literal_choices}"
        else:
            type_description = typing._type_repr(annotation)

        yield (
            param_name,
            _UNDEFINED if default is param.empty else default,
            type_description,
        )
@compose("\n".join)
def make_jinja_template(
    network_cls: nn.Module,
    *,
    exclude_list: set[str] = set(),
    defaults: dict[str, Any]={},
    top_level: bool = True,
    commented: bool = False,
    name=None,
    comment: Optional[str] = None,
    special_encoders: dict[str, Callable[[Any], str]]={},
) -> str:
    """Generate a YAML-style template section listing the parameters of `network_cls`.

    The section is a header line (`<name>:`, optionally followed by
    `# <comment>`) and then one aligned `key : value # type` line per
    parameter reported by `list_func_params`, followed by an empty line.
    Lines are yielded one at a time; the `compose` decorator is applied with
    `"\\n".join` (presumably joining them into the final string -- see
    `.utils.helpers.compose`).

    Args:
        network_cls: callable whose signature is inspected.
        exclude_list: parameter names to leave out.
        defaults: default values overriding those in the signature.
        top_level: if false, the header line is emitted as a comment.
        commented: if true, every parameter key is prefixed with `#`.
        name: section name; defaults to `network_cls.__name__`.
        comment: optional single-line comment appended to the header.
        special_encoders: per-key functions used instead of
            `_yaml_encode_value` to render that key's default.

    Raises:
        ValueError: if `comment` contains a newline.
    """
    key_prefix = "#" if commented else ""
    if name is None:
        name = network_cls.__name__
    if comment is not None and "\n" in comment:
        raise ValueError("newline in jinja template comment is not allowed")

    header = f"{name}:" if top_level else f" # {name}:"

    hparams = list(list_func_params(network_cls, exclude_list, defaults=defaults))
    if not hparams:
        # nothing to list: bare header only (no comment, no trailing blank line)
        yield header
        return

    # render every default to text; parameters without a default get an empty value
    rows = []
    for key, value, type_hint in hparams:
        encode = special_encoders[key] if key in special_encoders else _yaml_encode_value
        text = "" if value is _UNDEFINED else encode(value)
        rows.append((key, text, type_hint))

    # column widths used to align the ":" and "#" separators
    key_width, value_width = elementwise_max(
        (len(key), len(text))
        for key, text, _ in rows
    )

    yield f"{header} # {comment}" if comment else header

    for key, text, type_hint in rows:
        if key in exclude_list:
            continue
        key_pad = " " * (key_width - len(key))
        value_pad = " " * (value_width - len(text))
        yield f" {key_prefix}{key}{key_pad} : {text}{value_pad} # {type_hint}"
    yield ""
# helpers:
def squash_newlines(data: str) -> str:
    """Collapse every run of three or more newlines in `data` down to exactly two."""
    excess_newlines = re.compile(r'\n\n\n+')
    return excess_newlines.sub('\n\n', data)