206 lines
5.8 KiB
Python
206 lines
5.8 KiB
Python
from functools import wraps, reduce, partial
|
|
from itertools import zip_longest, groupby
|
|
from pathlib import Path
|
|
from typing import Iterable, TypeVar, Callable, Union, Optional, Mapping, Hashable
|
|
import collections
|
|
import operator
|
|
import re
|
|
|
|
# Union of the built-in numeric scalar types.
Numeric = Union[int, float, complex]

# Generic type variables used in the annotations of the helpers below.
T, S = TypeVar("T"), TypeVar("S")
|
|
|
|
# decorators
|
|
def compose(
    outer_func: Callable[..., T], *outer_a, **outer_kw
) -> Callable[[Callable[..., S]], Callable[..., T]]:
    """Decorator factory: pipe the decorated function's result into *outer_func*.

    The decorated function ``f`` is replaced by a wrapper returning
    ``outer_func(*outer_a, f(*a, **kw), **outer_kw)``: the result of ``f`` is
    passed as the positional argument that follows *outer_a*.
    ``functools.wraps`` keeps ``f``'s name and docstring on the wrapper.

    Example::

        @compose(list)
        def evens(n):
            return range(0, n, 2)

        evens(5)  # -> [0, 2, 4]
    """
    def wrapper(inner_func: Callable[..., S]) -> Callable[..., T]:
        @wraps(inner_func)
        def wrapped(*a, **kw) -> T:
            return outer_func(*outer_a, inner_func(*a, **kw), **outer_kw)

        return wrapped

    return wrapper
|
|
|
|
def compose_star(
    outer_func: Callable[..., T], *outer_a, **outer_kw
) -> Callable[[Callable[..., Iterable[S]]], Callable[..., T]]:
    """Like ``compose``, but star-unpack the decorated function's result.

    The decorated function ``f`` is replaced by a wrapper returning
    ``outer_func(*outer_a, *f(*a, **kw), **outer_kw)``, so ``f`` must return an
    iterable whose items become separate positional arguments after *outer_a*.
    ``functools.wraps`` keeps ``f``'s name and docstring on the wrapper.

    Example::

        @compose_star(max)
        def pair():
            return (3, 7)

        pair()  # -> 7
    """
    def wrapper(inner_func: Callable[..., Iterable[S]]) -> Callable[..., T]:
        @wraps(inner_func)
        def wrapped(*a, **kw) -> T:
            return outer_func(*outer_a, *inner_func(*a, **kw), **outer_kw)

        return wrapped

    return wrapper
|
|
|
|
|
|
# itertools
|
|
|
|
def elementwise_max(iterable: Iterable[Iterable[T]]) -> list[T]:
    """Return the position-wise maximum across several rows.

    ``elementwise_max([[1, 5, 3], [4, 2, 6]]) == [4, 5, 6]``

    Rows are combined pairwise with ``zip``, so the result is truncated to the
    shortest row. Always returns a new list: a single row is copied into a
    list, and an empty *iterable* yields ``[]``.
    """
    rows = iter(iterable)
    try:
        first = next(rows)
    except StopIteration:
        # nothing to compare; avoid reduce()'s "empty iterable" TypeError
        return []
    # seed with a list copy so the return type does not depend on row count
    return reduce(lambda acc, row: [*map(max, zip(acc, row))], rows, list(first))
|
|
|
|
def prod(numbers: Iterable[T], initial: Optional[T] = None) -> T:
    """Multiply all items of *numbers* together with ``operator.mul``.

    Any *initial* other than ``None`` seeds the product, so an empty *numbers*
    then returns *initial*. Without it, the first item is the seed and an empty
    *numbers* raises ``TypeError`` (propagated from ``functools.reduce``).
    """
    # zero or one extra positional argument for reduce()
    seed = () if initial is None else (initial,)
    return reduce(operator.mul, numbers, *seed)
|
|
|
|
def run_length_encode(data: Iterable[T]) -> Iterable[tuple[T, int]]:
    """Lazily yield ``(value, run_length)`` pairs for consecutive equal items.

    ``list(run_length_encode("aaabcc")) == [("a", 3), ("b", 1), ("c", 2)]``

    Only adjacent duplicates are merged (``itertools.groupby`` semantics), so
    ``"aba"`` encodes to three runs of length 1.
    """
    # groupby yields a one-shot iterator per run, which has no len();
    # count it by consuming it instead.
    return (
        (value, sum(1 for _ in run))
        for value, run in groupby(data)
    )
|
|
|
|
|
|
# text conversion
|
|
|
|
def camel_to_snake_case(text: str, sep: str = "_", join_abbreviations: bool = False) -> str:
    """Convert ``CamelCase`` text to ``snake_case``.

    *text* is split before every ASCII capital letter (``A``-``Z``). Empty
    pieces are discarded, the rest are lower-cased and joined with *sep*:
    ``"HTTPServer"`` -> ``"h_t_t_p_server"``.

    With *join_abbreviations*, each run of consecutive one-character pieces is
    glued into a single word first: ``"HTTPServer"`` -> ``"http_server"``.
    This applies to any one-character piece, not only capitals, so ``"aB"``
    becomes ``"ab"``.
    """
    words = [word.lower() for word in re.split(r'(?=[A-Z])', text) if word]
    if not join_abbreviations:
        return sep.join(words)

    merged = []
    for is_single_char, run in groupby(words, key=lambda word: len(word) == 1):
        if is_single_char:
            merged.append("".join(run))
        else:
            merged.extend(run)
    return sep.join(merged)
|
|
|
|
def snake_to_camel_case(text: str) -> str:
    """Convert ``snake_case`` text to ``PascalCase`` (upper camel case).

    Each non-empty ``_``-separated part goes through ``str.capitalize`` (first
    character upper-cased, the rest lower-cased). Leading, trailing and
    repeated underscores are therefore dropped: ``"__my_var_"`` -> ``"MyVar"``.
    """
    return "".join(part.capitalize() for part in text.split("_") if part)
|
|
|
|
|
|
# textwrap
|
|
|
|
def columnize_dict(data: dict, n_columns=2, prefix="", sep=" ") -> str:
    """Render *data* as ``repr(key) : repr(value),`` rows in side-by-side columns.

    Entries are split, in iteration order, into consecutive chunks of
    ``ceil(len(data) / n_columns)`` entries; each chunk becomes one column
    (built with ``columnize``). *prefix* is prepended to every key of the
    first column only. *sep* separates adjacent columns.

    Raises ZeroDivisionError when *n_columns* is 0.
    """
    # Render every entry once instead of once per column.
    key_lines = [repr(key) for key in data.keys()]
    value_lines = [f": {value!r}," for value in data.values()]
    # Ceiling division: with fewer entries per column, trailing entries would
    # not fit into the n_columns chunks and would be silently dropped.
    per_column = (len(data) + n_columns - 1) // n_columns

    def column(n: int) -> str:
        # keys and values of entries [n*per_column, (n+1)*per_column)
        chunk = slice(n * per_column, (n + 1) * per_column)
        keys = "\n".join(f"{'' if n else prefix}{key}" for key in key_lines[chunk])
        values = "\n".join(value_lines[chunk])
        return columnize(keys, values)

    return reduce(partial(columnize, sep=sep), (column(n) for n in range(n_columns)))
|
|
|
|
def columnize(left: str, right: str, prefix="", sep=" ") -> str:
    """Lay two multi-line strings out side by side.

    Each output row is ``prefix + left_row.ljust(width) + sep + right_row``,
    where *width* is the length of the longest line of *left*. A row whose
    right-hand part is empty is emitted unpadded as ``prefix + left_row``.
    When one side has fewer lines, it is filled with empty lines.
    """
    left_rows = left.split("\n")
    right_rows = right.split("\n")
    width = max((len(row) for row in left_rows), default=0)

    rows = []
    for lhs, rhs in zip_longest(left_rows, right_rows, fillvalue=""):
        if not rhs:
            rows.append(f"{prefix}{lhs}")
            continue
        rows.append(f"{prefix}{lhs.ljust(width)}{sep}{rhs}")
    return "\n".join(rows)
|
|
|
|
|
|
# pathlib
|
|
|
|
def make_relative(path: Union[Path, str], parent: Optional[Union[Path, str]] = None) -> Path:
    """Express *path* relative to *parent* (default: the current working directory).

    Tries, in order:

    1. *path* relative to *parent* (``/a/b/c`` vs ``/a/b`` -> ``c``);
    2. *path* relative to *parent*'s parent, prefixed with ``..``
       (``/a/x`` vs ``/a/b`` -> ``../x``);
    3. otherwise returns *path* unchanged.

    The comparison is lexical (``PurePath.relative_to``); symlinks are not
    resolved. Either argument may be a ``str`` or a ``Path``.
    """
    if isinstance(path, str):
        path = Path(path)
    if parent is None:
        parent = Path.cwd()
    elif isinstance(parent, str):
        # needed for the ``parent.parent`` step below
        parent = Path(parent)
    try:
        return path.relative_to(parent)
    except ValueError:
        pass  # not inside parent; try one level up
    try:
        return ".." / path.relative_to(parent.parent)
    except ValueError:
        return path  # unrelated location: leave it as given
|
|
|
|
|
|
# dictionaries
|
|
|
|
def update_recursive(target: dict, source: dict) -> None:
    """Merge *source* into *target* in place, descending into nested dicts.

    For every ``source[key]`` that is a ``dict``, the matching ``target[key]``
    is merged recursively. If that entry is missing or is not a mutable
    mapping, it is first replaced by a new empty container of
    ``type(target)``. All other values overwrite ``target[key]`` directly
    (assigned by reference, not copied). Nested dicts from *source* are
    therefore rebuilt rather than aliased into *target*.
    """
    for key, value in source.items():
        if isinstance(value, dict):
            # A non-mapping existing value (e.g. a scalar) cannot be merged
            # into; source wins, as it does for plain values.
            if not isinstance(target.get(key), collections.abc.MutableMapping):
                target[key] = type(target)()
            update_recursive(target[key], value)
        else:
            target[key] = value
|
|
|
|
def map_tree(
    func: Callable[[T], S],
    val: Union[Mapping[Hashable, T], tuple[T, ...], list[T], T],
) -> Union[Mapping[Hashable, S], tuple[S, ...], list[S], S]:
    """Apply *func* to every leaf of a nested structure, keeping its shape.

    Any ``Mapping`` is rebuilt as a plain ``dict`` with the same keys. Tuples
    (including their subclasses) become plain tuples, and lists stay lists.
    Everything else counts as a leaf and is replaced by ``func(leaf)``. Leaves
    are visited in iteration order.
    """
    recurse = partial(map_tree, func)
    if isinstance(val, collections.abc.Mapping):
        return {key: recurse(item) for key, item in val.items()}
    if isinstance(val, (tuple, list)):
        mapped = [recurse(item) for item in val]
        return tuple(mapped) if isinstance(val, tuple) else mapped
    return func(val)
|
|
|
|
def flatten_tree(val, *, sep=".", prefix=None):
    """Flatten nested mappings/lists/tuples into a single-level dict.

    Mapping children are addressed by their key, and list/tuple children by
    ``"[index]"``. Nested paths are joined with *sep*, e.g.
    ``flatten_tree({"a": {"b": [1]}}) == {"a.b.[0]": 1}``. Top-level keys are
    kept as-is (not stringified) when there is no *prefix*. A truthy *prefix*
    is prepended to every path. Empty containers contribute no entries, and
    on path collisions the later entry wins.

    A non-container *val* with no (or a falsy) *prefix* is returned unchanged.
    With a truthy *prefix*, it is returned as ``{prefix: val}``.
    """
    def walk(node, path, has_path):
        # has_path is tracked separately from path so falsy keys such as 0,
        # "" or None are still treated as real path components.
        if isinstance(node, collections.abc.Mapping):
            labelled = node.items()
        elif isinstance(node, (tuple, list)):
            labelled = ((f"[{index}]", subval) for index, subval in enumerate(node))
        elif has_path:
            return {path: node}
        else:
            return node

        flat = {}
        for label, subval in labelled:
            child_path = f"{path}{sep}{label}" if has_path else label
            flat.update(walk(subval, child_path, True))
        return flat

    # Only the caller-supplied prefix keeps the original truthiness rule.
    return walk(val, prefix, bool(prefix))
|
|
|
|
# conversions
|
|
|
|
def hex2tuple(data: str) -> tuple[int, ...]:
    """Parse a hex string such as ``"#ff8000"`` into a tuple of byte values.

    One leading ``#`` is stripped, then every two characters are parsed as a
    base-16 integer: ``hex2tuple("#ff8000") == (255, 128, 0)``. ``int()``
    raises ValueError for a pair that is not valid hex.

    NOTE: an odd number of digits is not rejected. The last digit is parsed on
    its own (``"#abc"`` -> ``(171, 12)``), so CSS shorthand colours are not
    expanded.
    """
    digits = data.removeprefix("#")
    return tuple(int(digits[i:i + 2], 16) for i in range(0, len(digits), 2))
|
|
|
|
|
|
# repr shims
|
|
|
|
class CustomRepr:
    """Object whose ``repr()`` and ``str()`` are both a fixed, caller-supplied string.

    Handy when something must be embedded verbatim in a ``repr`` of a larger
    structure (e.g. ``repr([CustomRepr("<x>")]) == "[<x>]"``).
    """

    def __init__(self, repr_str: str):
        self.repr_str = repr_str

    def __repr__(self):
        return self.repr_str

    # str() shows exactly the same text as repr()
    __str__ = __repr__
|
|
|
|
|
|
# Meta Params Module proxy
|
|
|
|
class MetaModuleProxy:
    """Attribute proxy for a module with a layer of overridable parameters.

    Reading ``proxy.name`` returns ``params[name]`` when present and otherwise
    falls through to ``getattr(module, name)``. Assigning ``proxy.name = v``
    stores into *params*; the wrapped module is not modified. ``del proxy.name``
    removes an override again, re-exposing the module's attribute.

    *params* is kept by reference, so writes are visible to anyone else holding
    that mapping.
    """

    def __init__(self, module, params):
        # both assignments are routed by __setattr__ to the instance itself
        self._module = module
        self._params = params

    def __getattr__(self, name):
        # Only called when normal lookup fails. object.__getattribute__ is used
        # (rather than self._params) so an instance lacking _params, e.g. one
        # created without __init__, raises AttributeError instead of recursing.
        params = super().__getattribute__("_params")
        if name in params:
            return params[name]
        return getattr(super().__getattribute__("_module"), name)

    def __setattr__(self, name, value):
        if name in ("_params", "_module"):
            super().__setattr__(name, value)
        else:
            super().__getattribute__("_params")[name] = value

    def __delattr__(self, name):
        # Mirror __setattr__: everything except the two internals lives in params.
        if name in ("_params", "_module"):
            super().__delattr__(name)
            return
        params = super().__getattribute__("_params")
        try:
            del params[name]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no parameter {name!r}"
            ) from None
|