Add code
This commit is contained in:
205
ifield/utils/helpers.py
Normal file
205
ifield/utils/helpers.py
Normal file
@@ -0,0 +1,205 @@
|
||||
from functools import wraps, reduce, partial
|
||||
from itertools import zip_longest, groupby
|
||||
from pathlib import Path
|
||||
from typing import Iterable, TypeVar, Callable, Union, Optional, Mapping, Hashable
|
||||
import collections
|
||||
import operator
|
||||
import re
|
||||
|
||||
Numeric = Union[int, float, complex]
|
||||
T = TypeVar("T")
|
||||
S = TypeVar("S")
|
||||
|
||||
# decorator
|
||||
def compose(outer_func: Callable[[..., S], T], *outer_a, **outer_kw) -> Callable[..., T]:
|
||||
def wrapper(inner_func: Callable[..., S]):
|
||||
@wraps(inner_func)
|
||||
def wrapped(*a, **kw):
|
||||
return outer_func(*outer_a, inner_func(*a, **kw), **outer_kw)
|
||||
return wrapped
|
||||
return wrapper
|
||||
|
||||
def compose_star(outer_func: Callable[[..., S], T], *outer_a, **outer_kw) -> Callable[..., T]:
|
||||
def wrapper(inner_func: Callable[..., S]):
|
||||
@wraps(inner_func)
|
||||
def wrapped(*a, **kw):
|
||||
return outer_func(*outer_a, *inner_func(*a, **kw), **outer_kw)
|
||||
return wrapped
|
||||
return wrapper
|
||||
|
||||
|
||||
# itertools
|
||||
|
||||
def elementwise_max(iterable: Iterable[Iterable[T]]) -> Iterable[T]:
|
||||
return reduce(lambda xs, ys: [*map(max, zip(xs, ys))], iterable)
|
||||
|
||||
def prod(numbers: Iterable[T], initial: Optional[T] = None) -> T:
|
||||
if initial is not None:
|
||||
return reduce(operator.mul, numbers, initial)
|
||||
else:
|
||||
return reduce(operator.mul, numbers)
|
||||
|
||||
def run_length_encode(data: Iterable[T]) -> Iterable[tuple[T, int]]:
|
||||
return (
|
||||
(x, len(y))
|
||||
for x, y in groupby(data)
|
||||
)
|
||||
|
||||
|
||||
# text conversion
|
||||
|
||||
def camel_to_snake_case(text: str, sep: str = "_", join_abbreviations: bool = False) -> str:
|
||||
parts = (
|
||||
part.lower()
|
||||
for part in re.split(r'(?=[A-Z])', text)
|
||||
if part
|
||||
)
|
||||
if join_abbreviations:
|
||||
parts = list(parts)
|
||||
if len(parts) > 1:
|
||||
for i, (a, b) in list(enumerate(zip(parts[:-1], parts[1:])))[::-1]:
|
||||
if len(a) == len(b) == 1:
|
||||
parts[i] = parts[i] + parts.pop(i+1)
|
||||
return sep.join(parts)
|
||||
|
||||
def snake_to_camel_case(text: str) -> str:
|
||||
return "".join(
|
||||
part.captialize()
|
||||
for part in text.split("_")
|
||||
if part
|
||||
)
|
||||
|
||||
|
||||
# textwrap
|
||||
|
||||
def columnize_dict(data: dict, n_columns=2, prefix="", sep=" ") -> str:
|
||||
sub = (len(data) + 1) // n_columns
|
||||
return reduce(partial(columnize, sep=sep),
|
||||
(
|
||||
columnize(
|
||||
"\n".join([f"{'' if n else prefix}{i!r}" for i in data.keys() ][n*sub : (n+1)*sub]),
|
||||
"\n".join([f": {i!r}," for i in data.values()][n*sub : (n+1)*sub]),
|
||||
)
|
||||
for n in range(n_columns)
|
||||
)
|
||||
)
|
||||
|
||||
def columnize(left: str, right: str, prefix="", sep=" ") -> str:
|
||||
left = left .split("\n")
|
||||
right = right.split("\n")
|
||||
width = max(map(len, left)) if left else 0
|
||||
return "\n".join(
|
||||
f"{prefix}{a.ljust(width)}{sep}{b}"
|
||||
if b else
|
||||
f"{prefix}{a}"
|
||||
for a, b in zip_longest(left, right, fillvalue="")
|
||||
)
|
||||
|
||||
|
||||
# pathlib
|
||||
|
||||
def make_relative(path: Union[Path, str], parent: Path = None) -> Path:
|
||||
if isinstance(path, str):
|
||||
path = Path(path)
|
||||
if parent is None:
|
||||
parent = Path.cwd()
|
||||
try:
|
||||
return path.relative_to(parent)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
return ".." / path.relative_to(parent.parent)
|
||||
except ValueError:
|
||||
pass
|
||||
return path
|
||||
|
||||
|
||||
# dictionaries
|
||||
|
||||
def update_recursive(target: dict, source: dict):
|
||||
""" Update two config dictionaries recursively. """
|
||||
for k, v in source.items():
|
||||
if isinstance(v, dict):
|
||||
if k not in target:
|
||||
target[k] = type(target)()
|
||||
update_recursive(target[k], v)
|
||||
else:
|
||||
target[k] = v
|
||||
|
||||
def map_tree(func: Callable[[T], S], val: Union[Mapping[Hashable, T], tuple[T, ...], list[T], T]) -> Union[Mapping[Hashable, S], tuple[S, ...], list[S], S]:
|
||||
if isinstance(val, collections.abc.Mapping):
|
||||
return {
|
||||
k: map_tree(func, subval)
|
||||
for k, subval in val.items()
|
||||
}
|
||||
elif isinstance(val, tuple):
|
||||
return tuple(
|
||||
map_tree(func, subval)
|
||||
for subval in val
|
||||
)
|
||||
elif isinstance(val, list):
|
||||
return [
|
||||
map_tree(func, subval)
|
||||
for subval in val
|
||||
]
|
||||
else:
|
||||
return func(val)
|
||||
|
||||
def flatten_tree(val, *, sep=".", prefix=None):
|
||||
if isinstance(val, collections.abc.Mapping):
|
||||
return {
|
||||
k: v
|
||||
for subkey, subval in val.items()
|
||||
for k, v in flatten_tree(subval, sep=sep, prefix=f"{prefix}{sep}{subkey}" if prefix else subkey).items()
|
||||
}
|
||||
elif isinstance(val, tuple) or isinstance(val, list):
|
||||
return {
|
||||
k: v
|
||||
for index, subval in enumerate(val)
|
||||
for k, v in flatten_tree(subval, sep=sep, prefix=f"{prefix}{sep}[{index}]" if prefix else f"[{index}]").items()
|
||||
}
|
||||
elif prefix:
|
||||
return {prefix: val}
|
||||
else:
|
||||
return val
|
||||
|
||||
# conversions
|
||||
|
||||
def hex2tuple(data: str) -> tuple[int]:
|
||||
data = data.removeprefix("#")
|
||||
return (*(
|
||||
int(data[i:i+2], 16)
|
||||
for i in range(0, len(data), 2)
|
||||
),)
|
||||
|
||||
|
||||
# repr shims
|
||||
|
||||
class CustomRepr:
|
||||
def __init__(self, repr_str: str):
|
||||
self.repr_str = repr_str
|
||||
def __str__(self):
|
||||
return self.repr_str
|
||||
def __repr__(self):
|
||||
return self.repr_str
|
||||
|
||||
|
||||
# Meta Params Module proxy
|
||||
|
||||
class MetaModuleProxy:
|
||||
def __init__(self, module, params):
|
||||
self._module = module
|
||||
self._params = params
|
||||
|
||||
def __getattr__(self, name):
|
||||
params = super().__getattribute__("_params")
|
||||
if name in params:
|
||||
return params[name]
|
||||
else:
|
||||
return getattr(super().__getattribute__("_module"), name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name not in ("_params", "_module"):
|
||||
super().__getattribute__("_params")[name] = value
|
||||
else:
|
||||
super().__setattr__(name, value)
|
||||
Reference in New Issue
Block a user