#!/usr/bin/env python3
from abc import abstractmethod, ABCMeta
from collections import namedtuple
from pathlib import Path
import copy
import dataclasses
import functools
import h5py as h5
import hdf5plugin
import numpy as np
import operator
import os
import sys
import typing

__all__ = [
    "DataclassMeta",
    "Dataclass",
    "H5Dataclass",
    "H5Array",
    "H5ArrayNoSlice",
]

T = typing.TypeVar("T")
NoneType = type(None)
PathLike = typing.Union[os.PathLike, str]
H5Array = typing._alias(np.ndarray, 0, inst=False, name="H5Array")
H5ArrayNoSlice = typing._alias(np.ndarray, 0, inst=False, name="H5ArrayNoSlice")
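
# Both aliases behave like `numpy.ndarray` in annotations. A field typed as
# H5Array participates in paging/slicing in `H5Dataclass.from_h5_file`, while
# a field typed as H5ArrayNoSlice is always read in full. Note that
# `typing._alias` is a private CPython helper whose signature has changed
# between Python versions.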


DataclassField = namedtuple("DataclassField", [
    "name",
    "type",
    "is_optional",
    "is_array",
    "is_sliceable",
    "is_prefix",
])


def strip_optional(val: type) -> type:
    if typing.get_origin(val) is typing.Union:
        union = set(typing.get_args(val))
        if len(union - {NoneType}) == 1:
            val, = union - {NoneType}
        else:
            raise TypeError(f"Non-'typing.Optional' 'typing.Union' is not supported: {typing._type_repr(val)!r}")
    return val
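
# Example (illustrative):
#
#     strip_optional(typing.Optional[int])    # -> int
#     strip_optional(int)                     # -> int (passed through)
#     strip_optional(typing.Union[int, str])  # raises TypeError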


def is_array(val, *, _inner=False):
    """
    Hacky way to check whether a value or type is an array.
    The hack avoids a hard dependency on large frameworks such as pytorch or pandas.
    """
    val = strip_optional(val)
    if val is H5Array or val is H5ArrayNoSlice:
        return True

    if typing._type_repr(val) in (
        "numpy.ndarray",
        "torch.Tensor",
    ):
        return True
    if not _inner:
        return is_array(type(val), _inner=True)
    return False
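
# Example (illustrative):
#
#     is_array(np.ndarray)   # True (matched by its type repr)
#     is_array(H5Array)      # True
#     is_array(np.zeros(3))  # True (instances are checked via their type)
#     is_array(3.0)          # False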


def prod(numbers: typing.Iterable[T], initial: typing.Optional[T] = None) -> T:
    if initial is not None:
        return functools.reduce(operator.mul, numbers, initial)
    else:
        return functools.reduce(operator.mul, numbers)
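
# Example (illustrative):
#
#     prod([2, 3, 4])      # -> 24
#     prod([], initial=1)  # -> 1; without `initial`, reduce() raises on an empty iterable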


class DataclassMeta(type):
    def __new__(
        mcls,
        name   : str,
        bases  : tuple[type, ...],
        attrs  : dict[str, typing.Any],
        **kwargs,
    ):
        cls = super().__new__(mcls, name, bases, attrs, **kwargs)
        if sys.version_info[:2] >= (3, 10) and not hasattr(cls, "__slots__"):
            cls = dataclasses.dataclass(slots=True)(cls)
        else:
            cls = dataclasses.dataclass(cls)
        return cls
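
# Example (illustrative): the metaclass applies @dataclasses.dataclass for you
# (with slots=True on Python >= 3.10 where possible):
#
#     class Pair(metaclass=DataclassMeta):
#         a : int
#         b : int
#
#     Pair(1, 2)  # an ordinary dataclass instance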


class DataclassABCMeta(DataclassMeta, ABCMeta):
    pass


class Dataclass(metaclass=DataclassMeta):
    def __getitem__(self, key: str) -> typing.Any:
        if key in self.keys():
            return getattr(self, key)
        raise KeyError(key)

    def __setitem__(self, key: str, value: typing.Any):
        if key in self.keys():
            return setattr(self, key, value)
        raise KeyError(key)

    def keys(self) -> typing.KeysView:
        return self.as_dict().keys()

    def values(self) -> typing.ValuesView:
        return self.as_dict().values()

    def items(self) -> typing.ItemsView:
        return self.as_dict().items()

    def as_dict(self, properties_to_include: typing.Optional[set[str]] = None, **kw) -> dict[str, typing.Any]:
        out = dataclasses.asdict(self, **kw)
        for name in (properties_to_include or []):
            out[name] = getattr(self, name)
        return out

    def as_tuple(self, properties_to_include: typing.Optional[list[str]] = None) -> tuple:
        out = dataclasses.astuple(self)
        if not properties_to_include:
            return out
        else:
            return (
                *out,
                *(getattr(self, name) for name in properties_to_include),
            )

    def copy(self: T, *, deep=True) -> T:
        return (copy.deepcopy if deep else copy.copy)(self)
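
# Example (illustrative): Dataclass subclasses are regular dataclasses with an
# additional mapping-style interface:
#
#     class Point(Dataclass):
#         x : float
#         y : float
#
#     p = Point(1.0, 2.0)
#     p["x"]           # -> 1.0
#     p["y"] = 3.0
#     dict(p.items())  # -> {"x": 1.0, "y": 3.0}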


class H5Dataclass(Dataclass):
    # settable with class params:
    _prefix      : str  = dataclasses.field(init=False, repr=False, default="")
    _n_pages     : int  = dataclasses.field(init=False, repr=False, default=10)
    _require_all : bool = dataclasses.field(init=False, repr=False, default=False)

    def __init_subclass__(cls,
        prefix      : typing.Optional[str]  = None,
        n_pages     : typing.Optional[int]  = None,
        require_all : typing.Optional[bool] = None,
        **kw,
    ):
        super().__init_subclass__(**kw)
        assert dataclasses.is_dataclass(cls)
        if prefix is not None:      cls._prefix      = prefix
        if n_pages is not None:     cls._n_pages     = n_pages
        if require_all is not None: cls._require_all = require_all
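
    # Example (illustrative): class params are passed in the class statement.
    # With prefix="pose_", a field `points` maps to the HDF5 key "pose_points";
    # require_all=True additionally demands that every key in the file is
    # consumed by some field:
    #
    #     class Pose(H5Dataclass, prefix="pose_", require_all=True):
    #         points : H5Array
    #         score  : float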

    @classmethod
    def _get_fields(cls) -> typing.Iterable[DataclassField]:
        for field in dataclasses.fields(cls):
            if not field.init:
                continue
            assert field.name not in ("_prefix", "_n_pages", "_require_all"), (
                f"{field.name!r} cannot be in {cls.__qualname__}.__init__.\n"
                "Set it with dataclasses.field(default=YOUR_VALUE, init=False, repr=False)"
            )
            if isinstance(field.type, str):
                raise TypeError("Type hints are strings; perhaps avoid using `from __future__ import annotations`")

            type_inner = strip_optional(field.type)
            is_prefix  = typing.get_origin(type_inner) is dict and typing.get_args(type_inner)[:1] == (str,)
            field_type = typing.get_args(type_inner)[1] if is_prefix else field.type
            if field.default is None or (typing.get_origin(field.type) is typing.Union and NoneType in typing.get_args(field.type)):
                field_type = typing.Optional[field_type]

            yield DataclassField(
                name         = field.name,
                type         = strip_optional(field_type),
                is_optional  = typing.get_origin(field_type) is typing.Union and NoneType in typing.get_args(field_type),
                is_array     = is_array(field_type),
                is_sliceable = is_array(field_type) and strip_optional(field_type) is not H5ArrayNoSlice,
                is_prefix    = is_prefix,
            )
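
    # Example (illustrative): a field annotated as `dict[str, H5Array]` is a
    # "prefix" field; it gathers every HDF5 key starting with
    # "<prefix><field name>_" into a single dict:
    #
    #     class Bundle(H5Dataclass):
    #         joints : dict[str, H5Array]  # reads "joints_left", "joints_right", ...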

    @classmethod
    def from_h5_file(cls : type[T],
        fname              : PathLike,
        *,
        page               : typing.Optional[int] = None,
        n_pages            : typing.Optional[int] = None,
        read_slice         : slice = slice(None),
        require_even_pages : bool = True,
    ) -> T:
        if not isinstance(fname, Path):
            fname = Path(fname)
        if n_pages is None:
            n_pages = cls._n_pages
        if not fname.exists():
            raise FileNotFoundError(str(fname))
        if not h5.is_hdf5(fname):
            raise TypeError(f"Not an HDF5 file: {str(fname)!r}")

        # if this class has no fields, raise with an example class definition:
        if not any(field.init for field in dataclasses.fields(cls)):
            with h5.File(fname, "r") as f:
                klen = max(map(len, f.keys()))
                example_cls = f"\nclass {cls.__name__}(H5Dataclass, require_all=True):\n" + "\n".join(
                    f"    {k.ljust(klen)} : "
                    + (
                        "H5Array"    if prod(v.shape, 1) > 1 else
                        "float"      if issubclass(v.dtype.type, np.floating) else
                        "int"        if issubclass(v.dtype.type, np.integer) else
                        "bool"       if issubclass(v.dtype.type, np.bool_) else
                        "typing.Any"
                    ).ljust(14 + 1)
                    + f" #{repr(v).split(':', 1)[1].removesuffix('>')}"
                    for k, v in f.items()
                )
            raise NotImplementedError(f"{cls!r} has no fields!\nPerhaps try the following:{example_cls}")
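
        # The raised hint looks roughly like this for a file holding a
        # (100, 3) float32 dataset "points" and a scalar float64 "score"
        # (illustrative):
        #
        #     class MyData(H5Dataclass, require_all=True):
        #         points : H5Array        # shape (100, 3), type "<f4"
        #         score  : float          # shape (1,), type "<f8"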

        fields_consumed = set()

        def make_kwarg(
            file  : h5.File,
            keys  : typing.KeysView,
            field : DataclassField,
        ) -> tuple[str, typing.Any]:
            if field.is_optional:
                if field.name not in keys:
                    return field.name, None
            if field.is_sliceable:
                if page is not None:
                    n_items  = int(file[cls._prefix + field.name].shape[0])
                    page_len = n_items // n_pages
                    modulus  = n_items % n_pages
                    if modulus: page_len += 1  # round up
                    if require_even_pages and modulus:
                        raise ValueError(f"Field {field.name!r} {tuple(file[cls._prefix + field.name].shape)} is not cleanly divisible into {n_pages} pages")
                    # slice() takes positional arguments only
                    this_slice = slice(
                        page_len * page,        # start
                        page_len * (page + 1),  # stop
                        read_slice.step,        # inherit step
                    )
                else:
                    this_slice = read_slice
            else:
                this_slice = slice(None)  # read all

            # array or scalar?
            def read_dataset(var):
                # https://docs.h5py.org/en/stable/high/dataset.html#reading-writing-data
                if field.is_array:
                    return var[this_slice]
                if var.shape == (1,):
                    return var[0]
                else:
                    return var[()]

            if field.is_prefix:
                fields_consumed.update(
                    key
                    for key in keys if key.startswith(f"{cls._prefix}{field.name}_")
                )
                return field.name, {
                    key.removeprefix(f"{cls._prefix}{field.name}_") : read_dataset(file[key])
                    for key in keys if key.startswith(f"{cls._prefix}{field.name}_")
                }
            else:
                fields_consumed.add(cls._prefix + field.name)
                return field.name, read_dataset(file[cls._prefix + field.name])

        with h5.File(fname, "r") as f:
            keys = f.keys()
            init_dict = dict(make_kwarg(f, keys, field) for field in cls._get_fields())

            try:
                out = cls(**init_dict)
            except Exception as e:
                class_attrs = set(field.name for field in dataclasses.fields(cls))
                file_attrs  = set(init_dict.keys())
                raise e.__class__(f"{e}. {class_attrs=}, {file_attrs=}, diff={class_attrs.symmetric_difference(file_attrs)}") from e

            if cls._require_all:
                fields_not_consumed = set(keys) - fields_consumed
                if fields_not_consumed:
                    raise ValueError(f"Not all HDF5 fields were consumed: {fields_not_consumed!r}")

        return out
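
    # Example (illustrative, reusing the hypothetical Pose class above):
    #
    #     pose = Pose.from_h5_file("data.h5")                      # read everything
    #     page = Pose.from_h5_file("data.h5", page=0, n_pages=10)  # first of 10 pages
    #     head = Pose.from_h5_file("data.h5", read_slice=slice(0, 100))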

    def to_h5_file(self,
        fname : PathLike,
        mkdir : bool = False,
    ):
        if not isinstance(fname, Path):
            fname = Path(fname)
        if not fname.parent.is_dir():
            if mkdir:
                fname.parent.mkdir(parents=True)
            else:
                raise NotADirectoryError(fname.parent)

        with h5.File(fname, "w") as f:
            for field in type(self)._get_fields():
                if field.is_optional and getattr(self, field.name) is None:
                    continue
                value = getattr(self, field.name)
                if field.is_array:
                    if any(type(i) is not np.ndarray for i in (value.values() if field.is_prefix else [value])):
                        raise TypeError(
                            "When dumping an H5Dataclass, make sure the array fields are "
                            f"numpy arrays (the type of {field.name!r} is {typing._type_repr(type(value))}).\n"
                            "Example: h5dataclass.map_arrays(torch.Tensor.numpy)"
                        )

                def write_value(key: str, value: typing.Any):
                    if field.is_array:
                        # compress arrays with LZ4 (provided by hdf5plugin)
                        f.create_dataset(key, data=value, **hdf5plugin.LZ4())
                    else:
                        f.create_dataset(key, data=value)

                if field.is_prefix:
                    for k, v in value.items():
                        write_value(self._prefix + field.name + "_" + k, v)
                else:
                    write_value(self._prefix + field.name, value)
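
    # Example (illustrative round trip, reusing the hypothetical Pose class):
    #
    #     pose = Pose(points=np.zeros((100, 3)), score=1.0)
    #     pose.to_h5_file("out/pose.h5", mkdir=True)
    #     same = Pose.from_h5_file("out/pose.h5")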

    def map_arrays(self: T, func: typing.Callable[[H5Array], H5Array], do_copy: bool = False) -> T:
        if do_copy:  # shallow
            self = self.copy(deep=False)
        for field in type(self)._get_fields():
            if field.is_optional and getattr(self, field.name) is None:
                continue
            if field.is_prefix and field.is_array:
                setattr(self, field.name, {
                    k : func(v)
                    for k, v in getattr(self, field.name).items()
                })
            elif field.is_array:
                setattr(self, field.name, func(getattr(self, field.name)))

        return self
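
    # Example (illustrative):
    #
    #     as_f32  = pose.map_arrays(lambda a: a.astype(np.float32), do_copy=True)
    #     tensors = pose.map_arrays(torch.from_numpy, do_copy=True)  # needs torch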

    def astype(self: T, t: type, do_copy: bool = False, convert_nonfloats: bool = False) -> T:
        # leave integer arrays untouched unless convert_nonfloats is set
        return self.map_arrays(
            lambda x: x.astype(t) if convert_nonfloats or not np.issubdtype(x.dtype, np.integer) else x,
            do_copy = do_copy,
        )

    def copy(self: T, *, deep=True) -> T:
        out = super().copy(deep=deep)
        if not deep:
            for field in type(self)._get_fields():
                if field.is_prefix:
                    # copy the dict itself so the two instances do not share it
                    out[field.name] = copy.copy(self[field.name])
        return out

    @property
    def shape(self) -> dict[str, tuple[int, ...]]:
        return {
            key: value.shape
            for key, value in self.items()
            if hasattr(value, "shape")
        }


class TransformableDataclassMixin(metaclass=DataclassABCMeta):

    @abstractmethod
    def transform(self: T, mat4: np.ndarray, inplace=False) -> T:
        ...

    def transform_to(self: T, name: str, inverse_name: typing.Optional[str] = None, *, inplace=False) -> T:
        mtx = self.transforms[name]
        out = self.transform(mtx, inplace=inplace)
        out.transforms.pop(name)  # consumed

        inv = np.linalg.inv(mtx)
        for key in list(out.transforms.keys()):  # maintain the other transforms
            out.transforms[key] = out.transforms[key] @ inv
        if inverse_name is not None:  # store inverse
            out.transforms[inverse_name] = inv

        return out
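
# Example (illustrative): a subclass provides a `transforms` mapping of 4x4
# matrices and implements transform(); transform_to() then rebases the
# remaining transforms onto the new frame:
#
#     class Cloud(H5Dataclass, TransformableDataclassMixin):
#         points     : H5Array                    # (N, 4) homogeneous coordinates
#         transforms : dict[str, H5ArrayNoSlice]  # a "prefix" field of matrices
#
#         def transform(self, mat4, inplace=False):
#             out = self if inplace else self.copy(deep=False)
#             out.points = out.points @ mat4.T
#             return out
#
#     world = cloud.transform_to("to_world", inverse_name="to_sensor")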