diff --git a/laydi/dataset.py b/laydi/dataset.py index 1ab7f1b..00591eb 100644 --- a/laydi/dataset.py +++ b/laydi/dataset.py @@ -1,11 +1,80 @@ -from scipy import ndarray,atleast_2d,asarray,intersect1d,zeros,empty,sparse,\ -where +from scipy import ndarray, atleast_2d, asarray, intersect1d, zeros +from scipy import empty, sparse, where from scipy import sort as array_sort from itertools import izip import shelve import copy import re +class Universe(object): + def __init__(self, name): + self.name = name + self._ids = {} + + def register(self, dim): + """Increase reference count for identifiers in Dimension object dim""" + if dim.name != self.name: + return + for i in dim: + self._ids[i] = self._ids.get(i, 0) + 1 + + def unregister(self, dim): + """Update reference count for identifiers in Dimension object dim + Update reference count for identifiers in Dimension object dim, and remove all + identifiers with a reference count of 0, as they do not (by definition) exist + any longer. + """ + if dim.name != self.name: + return + for i in dim: + refcount = self._ids[i] + if refcount == 1: + self._ids.pop(i) + else: + self._ids[i] -= 1 + + def __str__(self): + return "%s: %i elements, %i references" % (self.name, len(self._ids), sum(self._ids.values())) + + def __contains__(self, element): + return self._ids.__contains__(element) + + def __len__(self): + return len(self._ids) + + def intersection(self, dim): + return set(self._ids).intersection(dim.idset) + + +class Dimension(object): + """A Dimension represents the set of identifiers an object has along an axis. + """ + def __init__(self, name, ids=[]): + self.name = name + self.idset = set(ids) + self.idlist = list(ids) + + def __getitem__(self, element): + return self.idlist[element] + + def __getslice__(self, start, end): + return self.idlist[start:end] + + def __contains__(self, element): + return self.idset.__contains__(element) + + def __str__(self): + return "%s: %s" % (self.name, str(self.idlist)) + + def __len__(self): + return len(self.idlist) + + def __iter__(self): + return iter(self.idlist) + + def intersection(self, dim): + return self.idset.intersection(dim.idset) + class Dataset(object): """The Dataset base class.