From 00dd3e8d9d84658c215bf930b482ff3cae5b10b4 Mon Sep 17 00:00:00 2001 From: flatberg Date: Wed, 14 Mar 2007 16:06:16 +0000 Subject: [PATCH] added validation on identifiers input --- fluents/dataset.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/fluents/dataset.py b/fluents/dataset.py index de3f566..9ee083f 100644 --- a/fluents/dataset.py +++ b/fluents/dataset.py @@ -42,20 +42,16 @@ class Dataset: self._name = name self._identifiers = identifiers self._type = 'n' - + if len(array.shape)==1: array = atleast_2d(asarray(array)) # vectors are column vectors if array.shape[0]==1: array = array.T self.shape = array.shape + if identifiers!=None: - identifier_shape = [len(i[1]) for i in identifiers] - if len(identifier_shape)!=len(self.shape): - raise ValueError, "Identifier list length must equal array dims" - for ni, na in zip(identifier_shape, self.shape): - if ni!=na: - raise ValueError, "identifier-array mismatch in %s: (idents: %s, array: %s)" %(self._name, ni, na) + self._validate_identifiers(identifiers) self._set_identifiers(identifiers, self._all_dims) else: self._identifiers = self._create_identifiers(self.shape, self._all_dims) @@ -157,8 +153,8 @@ class Dataset: given dim. Index (Indices) are the Identifiers position in a matrix in a given dim. """ - if indices!=None: - if len(indices)==0:# if empty list or empty array + if indices != None: + if len(indices) == 0:# if empty list or empty array return [] if indices != None: @@ -166,7 +162,7 @@ class Dataset: #indices = intersect1d(self.get_indices(dim),indices) ids = [self._map[dim].reverse[i] for i in indices] else: - if sorted==True: + if sorted == True: ids = [self._map[dim].reverse[i] for i in array_sort(self._map[dim].values())] else: ids = self._map[dim].keys() @@ -197,7 +193,18 @@ class Dataset: """ return copy.deepcopy(self) - + def _validate_identifiers(self, identifiers): + + for dim_name, ids in identifiers: + if len(set(ids)) != len(ids): + raise ValueError("Identifiers not unique in : %s" %dim_name) + identifier_shape = [len(i[1]) for i in identifiers] + if len(identifier_shape)!=len(self.shape): + raise ValueError("Identifier list length must equal array dims") + for ni, na in zip(identifier_shape, self.shape): + if ni != na: + raise ValueError, "Identifier-array mismatch: %s: (idents: %s, array: %s)" %(self._name, ni, na) + class CategoryDataset(Dataset): """The category dataset class. @@ -490,4 +497,3 @@ def read_ftsv(fd): return ds -