Add code
This commit is contained in:
196
ifield/datasets/common.py
Normal file
196
ifield/datasets/common.py
Normal file
@@ -0,0 +1,196 @@
|
||||
from ..data.common.h5_dataclasses import H5Dataclass, PathLike
|
||||
from torch.utils.data import Dataset, IterableDataset
|
||||
from typing import Any, Iterable, Hashable, TypeVar, Iterator, Callable
|
||||
from functools import partial, lru_cache
|
||||
import inspect
|
||||
|
||||
|
||||
# Generic type variables used throughout this module.
T = TypeVar("T")  # arbitrary item type (used e.g. by FunctionDataset.__getitem__)
T_H5 = TypeVar("T_H5", bound=H5Dataclass)  # any H5Dataclass subclass loaded from HDF5
|
||||
|
||||
|
||||
class TransformableDatasetMixin:
    """Mixin adding chainable per-item transforms to torch datasets.

    Map-style `Dataset` subclasses get `__getitem__` wrapped so every
    registered transform is applied to the fetched item; `IterableDataset`
    subclasses get `__iter__` wrapped so transforms are mapped lazily over
    the stream.  A subclass may set
    `_transformable_mixin_no_override_getitem = True` to opt out and apply
    `self._transforms` itself.
    """

    def __init_subclass__(cls, **kw):
        # Forward kwargs so the cooperative __init_subclass__ chain stays intact.
        super().__init_subclass__(**kw)
        if getattr(cls, "_transformable_mixin_no_override_getitem", False):
            pass  # subclass handles self._transforms on its own
        # NOTE: torch's IterableDataset is itself a subclass of Dataset, so it
        # must be checked FIRST — checking Dataset first makes the iterable
        # branch unreachable and leaves __iter__ unwrapped.
        elif issubclass(cls, IterableDataset):
            if cls.__iter__ is not cls._transformable_mixin_iter_wrapper:
                cls._transformable_mixin_inner_iter = cls.__iter__
                cls.__iter__ = cls._transformable_mixin_iter_wrapper
        elif issubclass(cls, Dataset):
            if cls.__getitem__ is not cls._transformable_mixin_getitem_wrapper:
                cls._transformable_mixin_inner_getitem = cls.__getitem__
                cls.__getitem__ = cls._transformable_mixin_getitem_wrapper
        else:
            raise TypeError(f"{cls.__name__!r} is neither a Dataset nor a IterableDataset!")

    def __init__(self, *a, **kw):
        super().__init__(*a, **kw)
        self._transforms = []  # transforms are applied in registration order

    # works as a decorator
    def map(self, func: Callable = None, /, args=(), **kw) -> "T":
        """Register *func* as a transform and return ``self`` (chainable).

        Usable directly (``ds.map(f)``) or as a decorator (``@ds.map`` or
        ``@ds.map(args=..., key=...)``).  Extra ``args``/``kw`` are bound
        via ``functools.partial``.  The ``args`` default is an immutable
        tuple (was a shared mutable list).
        """
        def wrapper(func) -> "T":
            if args or kw:
                func = partial(func, *args, **kw)
            self._transforms.append(func)
            return self

        if func is None:
            return wrapper  # decorator-factory usage
        else:
            return wrapper(func)

    def _transformable_mixin_getitem_wrapper(self, index: int):
        # Fetch the raw item, then pipe it through every registered transform.
        out = self._transformable_mixin_inner_getitem(index)
        for f in self._transforms:
            out = f(out)
        return out

    def _transformable_mixin_iter_wrapper(self):
        # Wrap the raw iterator; transforms are applied lazily via map().
        stream = self._transformable_mixin_inner_iter()
        for f in self._transforms:
            stream = map(f, stream)
        return stream
|
||||
|
||||
|
||||
class TransformedDataset(Dataset, TransformableDatasetMixin):
    """Wrap another dataset, applying a fixed sequence of transforms to each item.

    The transforms given at construction time are registered with the
    TransformableDatasetMixin machinery, which applies them around
    ``__getitem__``.
    """

    def __init__(self, dataset: Dataset, transforms: Iterable[callable]):
        super().__init__()
        self.dataset = dataset
        for transform in transforms:
            self.map(transform)

    def __len__(self):
        # Same length as the wrapped dataset.
        return len(self.dataset)

    def __getitem__(self, index: int):
        # Raw delegation only — the mixin's wrapper applies the transforms.
        return self.dataset[index]
|
||||
|
||||
|
||||
class TransformExtendedDataset(Dataset, TransformableDatasetMixin):
    """Extend a dataset by a factor of ``len(self._transforms)``.

    Each underlying item appears once per registered transform: consecutive
    indices cycle through the transforms before advancing to the next item.
    """
    # Opt out of the mixin's __getitem__ wrapping: transforms are applied
    # explicitly below, one per index, rather than all in a chain.
    _transformable_mixin_no_override_getitem = True
    def __init__(self, dataset: Dataset):
        super().__init__()
        self.dataset = dataset

    def __len__(self):
        # One entry per (item, transform) pair.
        return len(self.dataset) * len(self._transforms)

    def __getitem__(self, index: int):
        n_transforms = len(self._transforms)
        assert n_transforms > 0, f"{len(self._transforms) = }"

        inner_index, transform_index = divmod(index, n_transforms)
        return self._transforms[transform_index](self.dataset[inner_index])
|
||||
|
||||
|
||||
class CachedDataset(Dataset):
|
||||
# used to wrap an another dataset
|
||||
def __init__(self, dataset: Dataset, cache_size: int | None):
|
||||
super().__init__()
|
||||
self.dataset = dataset
|
||||
if cache_size is not None and cache_size > 0:
|
||||
self.cached_getter = lru_cache(cache_size, self.dataset.__getitem__)
|
||||
else:
|
||||
self.cached_getter = self.dataset.__getitem__
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
return self.cached_getter(index)
|
||||
|
||||
|
||||
class AutodecoderDataset(Dataset, TransformableDatasetMixin):
    """Pair every item of a dataset with a stable hashable key.

    Items come out as ``(key, item)`` tuples; the key order is fixed at
    construction.  Also exposes a dict-like ``keys``/``values``/``items`` API.

    Raises:
        ValueError: if the number of keys differs from ``len(dataset)``.
    """
    def __init__(self,
            keys    : Iterable[Hashable],
            dataset : Dataset,
            ):
        super().__init__()
        self.ad_mapping = list(keys)  # index -> key lookup table
        self.dataset = dataset
        if len(self.ad_mapping) != len(dataset):
            raise ValueError(f"__len__ mismatch between keys and dataset: {len(self.ad_mapping)} != {len(dataset)}")

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> tuple[Hashable, Any]:
        # (key, item) pair for autodecoder-style latent lookup by key.
        return self.ad_mapping[index], self.dataset[index]

    def keys(self) -> list[Hashable]:
        return self.ad_mapping

    def values(self) -> Iterator:
        return iter(self.dataset)

    def items(self) -> Iterable[tuple[Hashable, Any]]:
        return zip(self.keys(), self.values())
|
||||
|
||||
|
||||
class FunctionDataset(Dataset, TransformableDatasetMixin):
    """Compute items on demand by applying ``getter`` to each key.

    Optionally memoizes the getter with an LRU cache when ``cache_size``
    is a positive integer.
    """
    def __init__(self,
            getter     : Callable[[Hashable], T],
            keys       : list[Hashable],
            cache_size : int | None = None,
            ):
        super().__init__()
        use_cache = cache_size is not None and cache_size > 0
        self.getter = lru_cache(cache_size)(getter) if use_cache else getter
        self.keys = keys

    def __len__(self) -> int:
        return len(self.keys)

    def __getitem__(self, index: int) -> T:
        key = self.keys[index]
        return self.getter(key)
|
||||
|
||||
class H5Dataset(FunctionDataset):
    """FunctionDataset that loads one H5Dataclass instance per file.

    Extra keyword arguments (e.g. ``cache_size``) are forwarded to
    ``FunctionDataset``.
    """
    def __init__(self,
            h5_dataclass_cls : type[T_H5],
            fnames           : list[PathLike],
            **kw,
            ):
        loader = h5_dataclass_cls.from_h5_file  # classmethod: path -> instance
        super().__init__(getter=loader, keys=fnames, **kw)
|
||||
|
||||
class PaginatedH5Dataset(Dataset, TransformableDatasetMixin):
    """Expose each H5 file as ``n_pages`` separate items (pages).

    Index layout: item ``i`` maps to file ``i // n_pages``, page
    ``i % n_pages``.  Pages are loaded lazily via
    ``h5_dataclass_cls.from_h5_file``.
    """
    def __init__(self,
            h5_dataclass_cls   : type[T_H5],
            fnames             : list[PathLike],
            n_pages            : int  = 10,
            require_even_pages : bool = True,
            ):
        super().__init__()
        self.h5_dataclass_cls   = h5_dataclass_cls
        self.fnames             = fnames
        self.n_pages            = n_pages
        self.require_even_pages = require_even_pages

    def __len__(self) -> int:
        # One item per (file, page) pair.
        return len(self.fnames) * self.n_pages

    def __getitem__(self, index: int) -> T_H5:
        item, page = divmod(index, self.n_pages)

        # BUG FIX: was `self.fname[item]` — the attribute set in __init__
        # is `self.fnames`, so every access raised AttributeError.
        return self.h5_dataclass_cls.from_h5_file(
            fname              = self.fnames[item],
            page               = page,
            n_pages            = self.n_pages,
            require_even_pages = self.require_even_pages,
        )
|
||||
Reference in New Issue
Block a user