This commit is contained in:
2023-07-19 19:29:10 +02:00
parent b2a64395bd
commit 4f811cc4b0
60 changed files with 18209 additions and 1 deletions

205
ifield/utils/helpers.py Normal file
View File

@@ -0,0 +1,205 @@
from functools import wraps, reduce, partial
from itertools import zip_longest, groupby
from pathlib import Path
from typing import Iterable, TypeVar, Callable, Union, Optional, Mapping, Hashable
import collections
import operator
import re
Numeric = Union[int, float, complex]
T = TypeVar("T")
S = TypeVar("S")
# decorator
def compose(outer_func: Callable[[..., S], T], *outer_a, **outer_kw) -> Callable[..., T]:
def wrapper(inner_func: Callable[..., S]):
@wraps(inner_func)
def wrapped(*a, **kw):
return outer_func(*outer_a, inner_func(*a, **kw), **outer_kw)
return wrapped
return wrapper
def compose_star(outer_func: Callable[[..., S], T], *outer_a, **outer_kw) -> Callable[..., T]:
def wrapper(inner_func: Callable[..., S]):
@wraps(inner_func)
def wrapped(*a, **kw):
return outer_func(*outer_a, *inner_func(*a, **kw), **outer_kw)
return wrapped
return wrapper
# itertools
def elementwise_max(iterable: Iterable[Iterable[T]]) -> Iterable[T]:
return reduce(lambda xs, ys: [*map(max, zip(xs, ys))], iterable)
def prod(numbers: Iterable[T], initial: Optional[T] = None) -> T:
if initial is not None:
return reduce(operator.mul, numbers, initial)
else:
return reduce(operator.mul, numbers)
def run_length_encode(data: Iterable[T]) -> Iterable[tuple[T, int]]:
return (
(x, len(y))
for x, y in groupby(data)
)
# text conversion
def camel_to_snake_case(text: str, sep: str = "_", join_abbreviations: bool = False) -> str:
parts = (
part.lower()
for part in re.split(r'(?=[A-Z])', text)
if part
)
if join_abbreviations:
parts = list(parts)
if len(parts) > 1:
for i, (a, b) in list(enumerate(zip(parts[:-1], parts[1:])))[::-1]:
if len(a) == len(b) == 1:
parts[i] = parts[i] + parts.pop(i+1)
return sep.join(parts)
def snake_to_camel_case(text: str) -> str:
return "".join(
part.captialize()
for part in text.split("_")
if part
)
# textwrap
def columnize_dict(data: dict, n_columns=2, prefix="", sep=" ") -> str:
sub = (len(data) + 1) // n_columns
return reduce(partial(columnize, sep=sep),
(
columnize(
"\n".join([f"{'' if n else prefix}{i!r}" for i in data.keys() ][n*sub : (n+1)*sub]),
"\n".join([f": {i!r}," for i in data.values()][n*sub : (n+1)*sub]),
)
for n in range(n_columns)
)
)
def columnize(left: str, right: str, prefix="", sep=" ") -> str:
left = left .split("\n")
right = right.split("\n")
width = max(map(len, left)) if left else 0
return "\n".join(
f"{prefix}{a.ljust(width)}{sep}{b}"
if b else
f"{prefix}{a}"
for a, b in zip_longest(left, right, fillvalue="")
)
# pathlib
def make_relative(path: Union[Path, str], parent: Path = None) -> Path:
if isinstance(path, str):
path = Path(path)
if parent is None:
parent = Path.cwd()
try:
return path.relative_to(parent)
except ValueError:
pass
try:
return ".." / path.relative_to(parent.parent)
except ValueError:
pass
return path
# dictionaries
def update_recursive(target: dict, source: dict):
""" Update two config dictionaries recursively. """
for k, v in source.items():
if isinstance(v, dict):
if k not in target:
target[k] = type(target)()
update_recursive(target[k], v)
else:
target[k] = v
def map_tree(func: Callable[[T], S], val: Union[Mapping[Hashable, T], tuple[T, ...], list[T], T]) -> Union[Mapping[Hashable, S], tuple[S, ...], list[S], S]:
if isinstance(val, collections.abc.Mapping):
return {
k: map_tree(func, subval)
for k, subval in val.items()
}
elif isinstance(val, tuple):
return tuple(
map_tree(func, subval)
for subval in val
)
elif isinstance(val, list):
return [
map_tree(func, subval)
for subval in val
]
else:
return func(val)
def flatten_tree(val, *, sep=".", prefix=None):
if isinstance(val, collections.abc.Mapping):
return {
k: v
for subkey, subval in val.items()
for k, v in flatten_tree(subval, sep=sep, prefix=f"{prefix}{sep}{subkey}" if prefix else subkey).items()
}
elif isinstance(val, tuple) or isinstance(val, list):
return {
k: v
for index, subval in enumerate(val)
for k, v in flatten_tree(subval, sep=sep, prefix=f"{prefix}{sep}[{index}]" if prefix else f"[{index}]").items()
}
elif prefix:
return {prefix: val}
else:
return val
# conversions
def hex2tuple(data: str) -> tuple[int]:
data = data.removeprefix("#")
return (*(
int(data[i:i+2], 16)
for i in range(0, len(data), 2)
),)
# repr shims
class CustomRepr:
def __init__(self, repr_str: str):
self.repr_str = repr_str
def __str__(self):
return self.repr_str
def __repr__(self):
return self.repr_str
# Meta Params Module proxy
class MetaModuleProxy:
def __init__(self, module, params):
self._module = module
self._params = params
def __getattr__(self, name):
params = super().__getattribute__("_params")
if name in params:
return params[name]
else:
return getattr(super().__getattribute__("_module"), name)
def __setattr__(self, name, value):
if name not in ("_params", "_module"):
super().__getattribute__("_params")[name] = value
else:
super().__setattr__(name, value)