Files
marf/ifield/data/common/h5_dataclasses.py
2025-01-09 15:43:11 +01:00

371 lines
14 KiB
Python

#!/usr/bin/env python3
from abc import abstractmethod, ABCMeta
from collections import namedtuple
from pathlib import Path
import copy
import dataclasses
import functools
import h5py as h5
import hdf5plugin
import numpy as np
import operator
import os
import sys
import typing
# Public API of this module.
__all__ = [
    "DataclassMeta",
    "Dataclass",
    "H5Dataclass",
    "H5Array",
    "H5ArrayNoSlice",
]
T = typing.TypeVar("T")
# Alias so Optional detection can test membership in typing.get_args() results.
NoneType = type(None)
# Anything accepted where a filesystem path is expected.
PathLike = typing.Union[os.PathLike, str]
# Type-hint markers for array fields of H5Dataclass subclasses.
# NOTE(review): typing._alias is a private CPython API whose signature has
# changed between minor versions — verify against the targeted Python version.
H5Array = typing._alias(np.ndarray, 0, inst=False, name="H5Array")
H5ArrayNoSlice = typing._alias(np.ndarray, 0, inst=False, name="H5ArrayNoSlice")
# Per-field metadata produced by H5Dataclass._get_fields().
DataclassField = namedtuple("DataclassField", [
    "name",          # field name on the dataclass
    "type",          # field type with Optional stripped
    "is_optional",   # field may be absent in the HDF5 file -> None
    "is_array",      # array-like (H5Array / numpy.ndarray / torch.Tensor)
    "is_sliceable",  # array that takes paged/sliced reads (not H5ArrayNoSlice)
    "is_prefix",     # dict[str, T] stored as multiple "<name>_<key>" datasets
])
def strip_optional(val: type) -> type:
    """Unwrap ``typing.Optional[X]`` to ``X``; non-Union types pass through.

    Raises TypeError for any Union that is not simply ``Optional[X]``.
    """
    if typing.get_origin(val) is not typing.Union:
        return val
    members = set(typing.get_args(val)) - {type(None)}
    if len(members) != 1:
        raise TypeError(f"Non-'typing.Optional' 'typing.Union' is not supported: {typing._type_repr(val)!r}")
    return members.pop()
def is_array(val, *, _inner=False):
    """
    Hacky check whether a value or type is an array.

    Matching is done on the type's repr so that heavyweight frameworks
    (pytorch, pandas) never have to be imported just for this test.
    """
    stripped = strip_optional(val)
    if stripped is H5Array or stripped is H5ArrayNoSlice:
        return True
    if typing._type_repr(stripped) in (
        "numpy.ndarray",
        "torch.Tensor",
    ):
        return True
    # `val` may have been an instance rather than a type: retry once on its type.
    return False if _inner else is_array(type(stripped), _inner=True)
def prod(numbers: typing.Iterable[T], initial: typing.Optional[T] = None) -> T:
    """Multiply all items of *numbers* together.

    When *initial* is given it seeds the product (and is the result for an
    empty iterable); otherwise an empty iterable raises TypeError, mirroring
    ``functools.reduce``.
    """
    if initial is None:
        return functools.reduce(operator.mul, numbers)
    return functools.reduce(operator.mul, numbers, initial)
class DataclassMeta(type):
    """Metaclass that turns every class it creates into a dataclass.

    On Python >= 3.10, classes without an (inherited) ``__slots__`` get the
    ``slots=True`` variant, which rebuilds the class object — hence the
    reassignment of ``cls`` before returning.
    """
    def __new__(
        mcls,
        name  : str,
        bases : tuple[type, ...],
        attrs : dict[str, typing.Any],
        **kwargs,
    ):
        cls = super().__new__(mcls, name, bases, attrs, **kwargs)
        want_slots = sys.version_info[:2] >= (3, 10) and not hasattr(cls, "__slots__")
        if want_slots:
            return dataclasses.dataclass(slots=True)(cls)
        return dataclasses.dataclass(cls)
class DataclassABCMeta(DataclassMeta, ABCMeta):
    """Combined metaclass: dataclass generation plus abc.ABCMeta machinery."""
class Dataclass(metaclass=DataclassMeta):
    """Base dataclass exposing a Mapping-like read/write interface.

    Fields are reachable both as attributes and via subscription
    (``obj["field"]``).  Note that keys()/values()/items() are backed by
    ``dataclasses.asdict`` and therefore return *converted copies* of nested
    containers, whereas ``obj[key]`` returns the attribute itself.
    """

    def __getitem__(self, key: str) -> typing.Any:
        if key in self.keys():
            return getattr(self, key)
        raise KeyError(key)

    def __setitem__(self, key: str, value: typing.Any):
        if key in self.keys():
            return setattr(self, key, value)
        raise KeyError(key)

    def keys(self) -> typing.KeysView:
        return self.as_dict().keys()

    def values(self) -> typing.ValuesView:
        return self.as_dict().values()

    def items(self) -> typing.ItemsView:
        return self.as_dict().items()

    def as_dict(self, properties_to_include: typing.Optional[set[str]] = None, **kw) -> dict[str, typing.Any]:
        """Return the fields as a dict, optionally adding named properties.

        Extra keyword arguments are forwarded to ``dataclasses.asdict``.
        """
        out = dataclasses.asdict(self, **kw)
        for name in (properties_to_include or []):
            out[name] = getattr(self, name)
        return out

    def as_tuple(self, properties_to_include: typing.Optional[list[str]] = None) -> tuple:
        """Return the fields as a tuple, optionally appending named properties.

        The argument is now optional (was required) — passing None or an
        empty list yields the plain field tuple.
        """
        out = dataclasses.astuple(self)
        if not properties_to_include:
            return out
        return (
            *out,
            *(getattr(self, name) for name in properties_to_include),
        )

    def copy(self: T, *, deep=True) -> T:
        """Return a deep (default) or shallow copy of this instance."""
        return (copy.deepcopy if deep else copy.copy)(self)
class H5Dataclass(Dataclass):
    """Dataclass that can be read from and written to a HDF5 file.

    Subclass parameters (``class Foo(H5Dataclass, prefix="bar_")``):
        prefix      - string prepended to every dataset name in the file
        n_pages     - default page count for paged reads
        require_all - if True, from_h5_file() fails when datasets remain unconsumed
    """
    # settable with class params:
    _prefix      : str  = dataclasses.field(init=False, repr=False, default="")
    _n_pages     : int  = dataclasses.field(init=False, repr=False, default=10)
    _require_all : bool = dataclasses.field(init=False, repr=False, default=False)

    def __init_subclass__(cls,
            prefix      : typing.Optional[str]  = None,
            n_pages     : typing.Optional[int]  = None,
            require_all : typing.Optional[bool] = None,
            **kw,
            ):
        super().__init_subclass__(**kw)
        assert dataclasses.is_dataclass(cls)
        if prefix      is not None: cls._prefix      = prefix
        if n_pages     is not None: cls._n_pages     = n_pages
        if require_all is not None: cls._require_all = require_all

    @classmethod
    def _get_fields(cls) -> typing.Iterable[DataclassField]:
        """Yield a DataclassField describing each __init__ field of this class."""
        for field in dataclasses.fields(cls):
            if not field.init:
                continue
            assert field.name not in ("_prefix", "_n_pages", "_require_all"), (
                f"{field.name!r} can not be in {cls.__qualname__}.__init__.\n"
                "Set it with dataclasses.field(default=YOUR_VALUE, init=False, repr=False)"
            )
            if isinstance(field.type, str):
                raise TypeError("Type hints are strings, perhaps avoid using `from __future__ import annotations`")
            type_inner = strip_optional(field.type)
            # dict[str, T] fields are stored as several "<name>_<key>" datasets
            is_prefix = typing.get_origin(type_inner) is dict and typing.get_args(type_inner)[:1] == (str,)
            field_type = typing.get_args(type_inner)[1] if is_prefix else field.type
            # a None default also marks the field optional in the file
            if field.default is None or typing.get_origin(field.type) is typing.Union and NoneType in typing.get_args(field.type):
                field_type = typing.Optional[field_type]
            yield DataclassField(
                name         = field.name,
                type         = strip_optional(field_type),
                is_optional  = typing.get_origin(field_type) is typing.Union and NoneType in typing.get_args(field_type),
                is_array     = is_array(field_type),
                is_sliceable = is_array(field_type) and strip_optional(field_type) is not H5ArrayNoSlice,
                is_prefix    = is_prefix,
            )

    @classmethod
    def from_h5_file(cls : type[T],
            fname : typing.Union[PathLike, str],
            *,
            page               : typing.Optional[int] = None,
            n_pages            : typing.Optional[int] = None,
            read_slice         : slice                = slice(None),
            require_even_pages : bool                 = True,
            ) -> T:
        """Construct an instance from a HDF5 file.

        When `page` is given, each sliceable array field is read in
        `n_pages` contiguous chunks and only chunk `page` is returned
        (inheriting only the step of `read_slice`); otherwise `read_slice`
        is applied directly.  Raises FileNotFoundError / TypeError for a
        missing or non-HDF5 file, and ValueError on uneven pages when
        `require_even_pages`.
        """
        if not isinstance(fname, Path):
            fname = Path(fname)
        if n_pages is None:
            n_pages = cls._n_pages
        if not fname.exists():
            raise FileNotFoundError(str(fname))
        if not h5.is_hdf5(fname):
            raise TypeError(f"Not a HDF5 file: {str(fname)!r}")

        # if this class has no fields, print a example class:
        if not any(field.init for field in dataclasses.fields(cls)):
            with h5.File(fname, "r") as f:
                klen = max(map(len, f.keys()))
                example_cls = f"\nclass {cls.__name__}(Dataclass, require_all=True):\n" + "\n".join(
                    f" {k.ljust(klen)} : "
                    + (
                    "H5Array" if prod(v.shape, 1) > 1 else (
                    "float" if issubclass(v.dtype.type, np.floating) else (
                    "int" if issubclass(v.dtype.type, np.integer) else (
                    "bool" if issubclass(v.dtype.type, np.bool_) else (
                    "typing.Any"
                    ))))).ljust(14 + 1)
                    + f" #{repr(v).split(':', 1)[1].removesuffix('>')}"
                    for k, v in f.items()
                )
            raise NotImplementedError(f"{cls!r} has no fields!\nPerhaps try the following:{example_cls}")

        fields_consumed = set()

        def make_kwarg(
                file  : h5.File,
                keys  : typing.KeysView,
                field : DataclassField,
                ) -> tuple[str, typing.Any]:
            # Purpose: read one dataclass field from the open file.
            # Optional fields missing from the file become None.
            if field.is_optional:
                if field.name not in keys:
                    return field.name, None

            if field.is_sliceable:
                if page is not None:
                    n_items  = int(file[cls._prefix + field.name].shape[0])
                    page_len = n_items // n_pages
                    modulus  = n_items % n_pages
                    if modulus: page_len += 1 # round up
                    if require_even_pages and modulus:
                        raise ValueError(f"Field {field.name!r} {tuple(file[cls._prefix + field.name].shape)} is not cleanly divisible into {n_pages} pages")
                    # BUGFIX: slice() accepts no keyword arguments — the old
                    # slice(start=..., stop=..., step=...) raised TypeError.
                    this_slice = slice(
                        page_len * page,      # start
                        page_len * (page+1),  # stop
                        read_slice.step,      # inherit step
                    )
                else:
                    this_slice = read_slice
            else:
                this_slice = slice(None) # read all

            # array or scalar?
            def read_dataset(var):
                # https://docs.h5py.org/en/stable/high/dataset.html#reading-writing-data
                if field.is_array:
                    return var[this_slice]
                if var.shape == (1,):
                    return var[0]
                else:
                    return var[()]

            if field.is_prefix:
                # gather every "<prefix><name>_<key>" dataset into one dict
                fields_consumed.update(
                    key
                    for key in keys if key.startswith(f"{cls._prefix}{field.name}_")
                )
                return field.name, {
                    key.removeprefix(f"{cls._prefix}{field.name}_") : read_dataset(file[key])
                    for key in keys if key.startswith(f"{cls._prefix}{field.name}_")
                }
            else:
                fields_consumed.add(cls._prefix + field.name)
                return field.name, read_dataset(file[cls._prefix + field.name])

        with h5.File(fname, "r") as f:
            keys = f.keys()
            init_dict = dict( make_kwarg(f, keys, i) for i in cls._get_fields() )
            try:
                out = cls(**init_dict)
            except Exception as e:
                # augment the error with the field/file diff to ease debugging
                class_attrs = set(field.name for field in dataclasses.fields(cls))
                file_attr   = set(init_dict.keys())
                raise e.__class__(f"{e}. {class_attrs=}, {file_attr=}, diff={class_attrs.symmetric_difference(file_attr)}") from e
            if cls._require_all:
                fields_not_consumed = set(keys) - fields_consumed
                if fields_not_consumed:
                    raise ValueError(f"Not all HDF5 fields consumed: {fields_not_consumed!r}")
        return out

    def to_h5_file(self,
            fname : PathLike,
            mkdir : bool = False,
            ):
        """Write this instance to a HDF5 file (parents created when `mkdir`).

        Raises NotADirectoryError when the parent directory is missing and
        `mkdir` is False, and TypeError when an array field is not a numpy
        array.
        """
        if not isinstance(fname, Path):
            fname = Path(fname)
        if not fname.parent.is_dir():
            if mkdir:
                fname.parent.mkdir(parents=True)
            else:
                raise NotADirectoryError(fname.parent)
        with h5.File(fname, "w") as f:
            for field in type(self)._get_fields():
                if field.is_optional and getattr(self, field.name) is None:
                    continue
                value = getattr(self, field.name)
                if field.is_array:
                    if any(type(i) is not np.ndarray for i in (value.values() if field.is_prefix else [value])):
                        raise TypeError(
                            "When dumping a H5Dataclass, make sure the array fields are "
                            f"numpy arrays (the type of {field.name!r} is {typing._type_repr(type(value))}).\n"
                            "Example: h5dataclass.map_arrays(torch.Tensor.numpy)"
                        )

                def write_value(key: str, value: typing.Any):
                    # arrays are LZ4-compressed; scalars stored as-is
                    if field.is_array:
                        f.create_dataset(key, data=value, **hdf5plugin.LZ4())
                    else:
                        f.create_dataset(key, data=value)

                if field.is_prefix:
                    for k, v in value.items():
                        write_value(self._prefix + field.name + "_" + k, v)
                else:
                    write_value(self._prefix + field.name, value)

    def map_arrays(self: T, func: typing.Callable[[H5Array], H5Array], do_copy: bool = False) -> T:
        """Apply `func` to every (non-None) array field, in place unless `do_copy`."""
        if do_copy: # shallow
            self = self.copy(deep=False)
        for field in type(self)._get_fields():
            if field.is_optional and getattr(self, field.name) is None:
                continue
            if field.is_prefix and field.is_array:
                setattr(self, field.name, {
                    k : func(v)
                    for k, v in getattr(self, field.name).items()
                })
            elif field.is_array:
                setattr(self, field.name, func(getattr(self, field.name)))
        return self

    def astype(self: T, t: type, do_copy: bool = False, convert_nonfloats: bool = False) -> T:
        """Cast array fields to dtype `t`; integer arrays only when `convert_nonfloats`."""
        # BUGFIX: do_copy was accepted but never forwarded, so the cast always
        # mutated the instance in place.
        return self.map_arrays(
            lambda x: x.astype(t) if convert_nonfloats or not np.issubdtype(x.dtype, int) else x,
            do_copy,
        )

    def copy(self: T, *, deep=True) -> T:
        """Copy the instance; a shallow copy still gets its own prefix dicts."""
        out = super().copy(deep=deep)
        if not deep:
            for field in type(self)._get_fields():
                if field.is_prefix:
                    # BUGFIX: the old code did copy.copy(field.name), which
                    # replaced the field with a copy of its *name string*.
                    out[field.name] = copy.copy(out[field.name])
        return out

    @property
    def shape(self) -> dict[str, tuple[int, ...]]:
        """Map of field name -> .shape for every field value that has one."""
        return {
            key: value.shape
            for key, value in self.items()
            if hasattr(value, "shape")
        }
class TransformableDataclassMixin(metaclass=DataclassABCMeta):
    """Mixin for dataclasses carrying named transformation matrices.

    NOTE(review): assumes the concrete class provides a ``transforms`` dict
    mapping names to invertible matrices (presumably 4x4 homogeneous
    transforms, given ``mat4``) — confirm against the implementing classes.
    """
    @abstractmethod
    def transform(self: T, mat4: np.ndarray, inplace=False) -> T:
        # Apply the given transformation matrix to this instance's data.
        ...
    def transform_to(self: T, name: str, inverse_name: typing.Optional[str] = None, *, inplace=False) -> T:
        # Apply the stored transform `name`, then rebase the remaining stored
        # transforms so they still map correctly from the new frame.
        mtx = self.transforms[name]
        out = self.transform(mtx, inplace=inplace)
        out.transforms.pop(name) # consumed
        inv = np.linalg.inv(mtx)
        for key in list(out.transforms.keys()): # maintain the other transforms
            out.transforms[key] = out.transforms[key] @ inv
        if inverse_name is not None: # store inverse
            out.transforms[inverse_name] = inv
        return out