Add code
This commit is contained in:
231
ifield/param.py
Normal file
231
ifield/param.py
Normal file
@@ -0,0 +1,231 @@
|
||||
from .utils.helpers import compose, elementwise_max
|
||||
from datetime import datetime
|
||||
from torch import nn
|
||||
from typing import Any, Literal, Iterable, Union, Callable, Optional
|
||||
import inspect
|
||||
import jinja2
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import shlex
|
||||
import string
|
||||
import sys
|
||||
import time
|
||||
import typing
|
||||
import warnings
|
||||
import yaml
|
||||
|
||||
# Sentinel standing in for "this parameter has no default value".
# Code in this module compares against it by identity (`is` / `is not`).
_UNDEFINED = " I AM UNDEFINED "
|
||||
|
||||
def _yaml_encode_value(val) -> str:
    """Render ``val`` as a string.

    Tuples and sets are first converted to lists. Lists and dicts are
    serialized with ``json.dumps``; any other value is passed through
    ``yaml.dump``, with a trailing ``"\\n...\\n"`` marker and trailing
    newlines stripped from the result.
    """
    if isinstance(val, (tuple, set)):
        val = list(val)

    if isinstance(val, (list, dict)):
        return json.dumps(val)

    dumped = yaml.dump(val)
    return dumped.removesuffix("\n...\n").rstrip("\n")
|
||||
|
||||
def _raise(val: Union[Exception, str]):
    """Raise ``val``; a plain string is wrapped in ``jinja2.TemplateError`` first.

    Exists so that raising can be done from expression context
    (it is exposed to templates and used inside a lambda in this module).
    """
    error = jinja2.TemplateError(val) if isinstance(val, str) else val
    raise error
|
||||
|
||||
def make_jinja_globals(*, enable_require_defined: bool) -> dict:
    """Build the mapping of global names made available to jinja templates.

    The returned dict contains, in order of increasing precedence:
    every local name of this function (the modules imported below, the
    ``enable_require_defined`` flag and the nested helper functions),
    everything in ``builtins``, ``"argv"`` (``sys.argv``) and ``"raise"``
    (``_raise``).

    Args:
        enable_require_defined: When true, ``require_defined`` raises on a
            missing or incorrect variable. When false it only emits a
            warning, unless it is called with ``strict=True``.

    Returns:
        A ``dict`` mapping global names to objects.
    """
    # NOTE: these imports look unused, but every local name of this function
    # is exported through ``locals()`` in the return statement below.
    # Do not add further local variables here unless they should be exported.
    import builtins
    import functools
    import itertools
    import operator
    import json

    def require_defined(name, value, *defaults, failed: bool = False, strict: bool=False, exchaustive=False):
        """Check that template variable ``name`` (with value ``value``) is usable.

        Args:
            name: Name of the variable, used in messages.
            value: Its value, possibly a jinja2 ``Undefined``.
            *defaults: Example/allowed values; at least one is required.
                They are listed as ``-O<name>=<value>`` suggestions.
            failed: Caller-supplied flag marking ``value`` as incorrect.
            strict: Raise instead of warn even when
                ``enable_require_defined`` is false.
            exchaustive: When true, a defined ``value`` must be one of
                ``defaults``.

        Raises:
            ValueError: If no defaults are given, if the variable is
                undefined/incorrect and raising is enabled, or if
                ``exchaustive`` is set and ``value`` is not in ``defaults``.
        """
        if not defaults:
            raise ValueError("`require_defined` requires at least one valid value provided")

        value_is_undefined = jinja2.is_undefined(value)
        if value_is_undefined:
            assert value._undefined_name == name, \
                f"Name mismatch: {value._undefined_name=}, {name=}"

        if failed or value_is_undefined:
            # Same text is used for both the error and the warning.
            message = (
                f"Required variable {name!r} "
                f"is {'incorrect' if failed else 'undefined'}! "
                f"Try providing:\n" + "\n".join(
                    f"-O{shlex.quote(name)}={shlex.quote(str(default))}"
                    for default in defaults
                )
            )
            if enable_require_defined or strict:
                raise ValueError(message)
            warnings.warn(message)

        if exchaustive and not value_is_undefined and value not in defaults:
            raise ValueError(
                f"Variable {name!r} not in list of allowed values: {defaults!r}"
            )

    def gen_run_uid(n: int, _choice = random.Random(time.time_ns()).choice):
        """
        generates a UID for the experiment run, nice for regexes, grepping and timekeeping.

        The result is the current local time formatted as ``%Y-%m-%d-%H%M``,
        a dash, and ``n`` random lowercase ASCII letters.
        """
        # we have _choice, since most likely, pl.seed_everything has been run by this point
        # we store it as a default parameter to reuse it, on the off-chance of two calls to this function being run within the same ns
        code = ''.join(_choice(string.ascii_lowercase) for _ in range(n))
        return f"{datetime.now():%Y-%m-%d-%H%M}-{code}"

    def cartesian_hparams(_map=None, **kw: dict[str, list]) -> Iterable[jinja2.utils.Namespace]:
        """Yield one namespace per element of the cartesian product of the value lists.

        Use this to bypass the common error 'SyntaxError: too many statically nested blocks'

        ``_map`` may be a jinja2 ``Namespace`` or a ``dict`` of value lists;
        keyword arguments override entries of ``_map`` with the same key.
        With no keys at all, a single empty namespace is yielded.
        """
        if isinstance(_map, jinja2.utils.Namespace):
            kw = _map._Namespace__attrs | kw
        elif isinstance(_map, dict):
            # A plain dict has no `_Namespace__attrs`; merge it directly.
            kw = _map | kw
        keys = list(kw)
        for combination in itertools.product(*kw.values()):
            yield jinja2.utils.Namespace(zip(keys, combination))

    def ablation_hparams(_map=None, *, caartesian_keys: list[str] = None, **kw: dict[str, list]) -> Iterable[jinja2.utils.Namespace]:
        """Yield a baseline namespace followed by one-at-a-time variations of it.

        Use this to bypass the common error 'SyntaxError: too many statically nested blocks'

        Keys listed in ``caartesian_keys`` are combined as a cartesian
        product. For each such combination, the baseline uses the first value
        of every remaining ("ablation") key; then, for each ablation key, one
        namespace is yielded per non-first value with only that key changed.

        ``_map`` may be a jinja2 ``Namespace`` or a ``dict`` of value lists;
        keyword arguments override entries of ``_map`` with the same key.
        """
        if isinstance(_map, jinja2.utils.Namespace):
            kw = _map._Namespace__attrs | kw
        elif isinstance(_map, dict):
            # A plain dict has no `_Namespace__attrs`; merge it directly.
            kw = _map | kw
        keys = list(kw.keys())

        # Preserve the ordering of `kw` for both key groups.
        caartesian_keys = [k for k in keys if k in caartesian_keys] if caartesian_keys else []
        ablation_keys = [k for k in keys if k not in caartesian_keys]
        caartesian_vals = list(map(kw.__getitem__, caartesian_keys))
        ablation_vals = list(map(kw.__getitem__, ablation_keys))

        for base_vals in itertools.product(*caartesian_vals):
            base = list(itertools.chain(zip(caartesian_keys, base_vals), zip(ablation_keys, [i[0] for i in ablation_vals])))
            yield jinja2.utils.Namespace(base)
            for ablation_key, ablation_val in zip(ablation_keys, ablation_vals):
                for val in ablation_val[1:]:
                    yield jinja2.utils.Namespace(base, **{ablation_key: val}) # ablation variation

    return {
        **locals(),
        **vars(builtins),
        "argv": sys.argv,
        "raise": _raise,
    }
|
||||
|
||||
def make_jinja_env(globals = make_jinja_globals(enable_require_defined=True), allow_undef=False) -> jinja2.Environment:
    """Create the ``jinja2.Environment`` used to render templates.

    Templates are looked up relative to the current working directory and
    the filesystem root (following symlinks). Autoescaping is off, block
    trimming/left-stripping is on, and the ``do`` and ``loopcontrols``
    extensions are enabled.

    Args:
        globals: Mapping merged into the environment's globals.
        allow_undef: Use ``jinja2.Undefined`` when true, otherwise
            ``jinja2.StrictUndefined``.

    Returns:
        The configured environment, with the extra filters ``defined``,
        ``repr``, ``to_json``, ``bool``, ``int``, ``float`` and ``str``.
    """
    def _defined(x):
        # Pass defined values through untouched, raise on undefined ones.
        if jinja2.is_undefined(x):
            _raise(f"{x._undefined_name!r} is not defined!")
        return x

    def _json_cast(cast):
        # Filter that converts with `cast` and emits the result as JSON.
        return lambda x: json.dumps(cast(x))

    if allow_undef:
        undefined_cls = jinja2.Undefined
    else:
        undefined_cls = jinja2.StrictUndefined

    search_path = [os.getcwd(), "/"]
    enabled_extensions = [
        "jinja2.ext.do",  # statements with side-effects
        "jinja2.ext.loopcontrols",  # break and continue
    ]

    env = jinja2.Environment(
        loader=jinja2.FileSystemLoader(search_path, followlinks=True),
        autoescape=False,
        trim_blocks=True,
        lstrip_blocks=True,
        undefined=undefined_cls,
        extensions=enabled_extensions,
    )
    env.globals.update(globals)

    extra_filters = {
        "defined": _defined,
        "repr": repr,
        "to_json": json.dumps,
        "bool": _json_cast(bool),
        "int": _json_cast(int),
        "float": _json_cast(float),
        "str": _json_cast(str),
    }
    env.filters.update(extra_filters)
    return env
|
||||
|
||||
def list_func_params(func: callable, exclude_list: set[str], defaults: dict={}) -> Iterable[tuple[str, Any, str]]:
    """Yield ``(name, default, type description)`` for parameters of ``func``.

    Skipped parameters: a leading ``self``/``cls``, names in
    ``exclude_list``, names starting with an underscore, and ``*args`` /
    ``**kwargs`` style parameters.

    Args:
        func: Callable whose signature is inspected.
        exclude_list: Parameter names to leave out.
        defaults: Overrides for the defaults declared in the signature.

    Yields:
        Tuples of the parameter name, its default (``_UNDEFINED`` when there
        is none) and either ``"in {a, b, ...}"`` for ``Literal`` annotations
        or the repr of the annotated type.

    Raises:
        AssertionError: If a listed parameter has no type annotation.
    """
    parameters = inspect.signature(func).parameters
    for index, (param_name, param) in enumerate(parameters.items()):
        is_bound_argument = index == 0 and param_name in {"self", "cls"}
        if is_bound_argument or param_name in exclude_list or param_name.startswith("_"):
            continue
        if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue

        default = defaults.get(param_name, param.default)
        if default is param.empty:
            default = _UNDEFINED

        annotation = param.annotation
        assert annotation is not param.empty, f"param {param_name!r} has no type annotation"

        if typing.get_origin(annotation) is Literal:
            choices = ', '.join(map(_yaml_encode_value, typing.get_args(annotation)))
            description = f"in {{{choices}}}"
        else:
            description = typing._type_repr(annotation)

        yield (param_name, default, description)
|
||||
|
||||
@compose("\n".join)
def make_jinja_template(
    network_cls: nn.Module,
    *,
    exclude_list: set[str] = set(),
    defaults: dict[str, Any]={},
    top_level: bool = True,
    commented: bool = False,
    name=None,
    comment: Optional[str] = None,
    special_encoders: dict[str, Callable[[Any], str]]={},
) -> str:
    """Generate a template section listing the parameters of ``network_cls``.

    The lines yielded here are post-processed by the ``compose`` decorator
    (presumably joined with newlines - confirm against ``utils.helpers``).

    Args:
        network_cls: Class/callable whose parameters are listed.
        exclude_list: Parameter names to leave out.
        defaults: Overrides for defaults declared in the signature.
        top_level: Emit ``<name>:`` as a real key when true, otherwise as a
            ``#`` comment line.
        commented: Prefix every parameter key with ``#``.
        name: Section name; defaults to ``network_cls.__name__``.
        comment: Optional single-line comment appended to the header. It is
            not emitted when there are no parameters to list.
        special_encoders: Per-parameter functions used instead of
            ``_yaml_encode_value`` to encode the default value.

    Raises:
        ValueError: If ``comment`` contains a newline.
    """
    if name is None:
        name = network_cls.__name__

    if comment is not None and "\n" in comment:
        raise ValueError("newline in jinja template comment is not allowed")

    header = f"{name}:" if top_level else f" # {name}:"

    hparams = list(list_func_params(network_cls, exclude_list, defaults=defaults))
    if not hparams:
        yield header
        return

    # Encode each default; parameters without a default get an empty value.
    rows = []
    for key, value, type_hint in hparams:
        if value is _UNDEFINED:
            text = ""
        elif key in special_encoders:
            text = special_encoders[key](value)
        else:
            text = _yaml_encode_value(value)
        rows.append((key, text, type_hint))

    # Column widths used to align the ':' and '#' separators.
    ml_key, ml_value = elementwise_max(
        (len(key), len(text))
        for key, text, _ in rows
    )

    if comment:
        header = f"{header} # {comment}"
    yield header

    key_prefix = "#" if commented else ""
    for key, text, type_hint in rows:
        if key in exclude_list:
            continue
        key_padding = ' ' * (ml_key - len(key))
        value_padding = ' ' * (ml_value - len(text))
        yield f" {key_prefix}{key}{key_padding} : {text}{value_padding} # {type_hint}"

    yield ""
|
||||
|
||||
# helpers:
|
||||
|
||||
def squash_newlines(data: str) -> str:
    """Collapse every run of three or more newlines in ``data`` into two."""
    excess_newlines = re.compile(r'\n\n\n+')
    return excess_newlines.sub('\n\n', data)
|
||||
Reference in New Issue
Block a user